diff --git a/cache/__init__.py b/cache/__init__.py index 2e88155..e69de29 100644 --- a/cache/__init__.py +++ b/cache/__init__.py @@ -1,5 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author : relakkes@gmail.com -# @Name : 程序员阿江-Relakkes -# @Time : 2024/6/2 11:05 -# @Desc : diff --git a/cache/abs_cache.py b/cache/abs_cache.py index 5558d82..5c592ee 100644 --- a/cache/abs_cache.py +++ b/cache/abs_cache.py @@ -5,10 +5,10 @@ # @Desc : 抽象类 from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, List, Optional -class Cache(ABC): +class AbstractCache(ABC): @abstractmethod def get(self, key: str) -> Optional[Any]: @@ -31,3 +31,12 @@ class Cache(ABC): :return: """ raise NotImplementedError + + @abstractmethod + def keys(self, pattern: str) -> List[str]: + """ + 获取所有符合pattern的key + :param pattern: 匹配模式 + :return: + """ + raise NotImplementedError diff --git a/cache/cache_factory.py b/cache/cache_factory.py new file mode 100644 index 0000000..bb1ec4d --- /dev/null +++ b/cache/cache_factory.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# @Author : relakkes@gmail.com +# @Name : 程序员阿江-Relakkes +# @Time : 2024/6/2 11:23 +# @Desc : + + +class CacheFactory: + """ + 缓存工厂类 + """ + + @staticmethod + def create_cache(cache_type: str, *args, **kwargs): + """ + 创建缓存对象 + :param cache_type: 缓存类型 + :param args: 参数 + :param kwargs: 关键字参数 + :return: + """ + if cache_type == 'memory': + from .local_cache import ExpiringLocalCache + return ExpiringLocalCache(*args, **kwargs) + elif cache_type == 'redis': + from .redis_cache import RedisCache + return RedisCache() + else: + raise ValueError(f'Unknown cache type: {cache_type}') diff --git a/cache/local_cache.py b/cache/local_cache.py index aebf56e..d56be50 100644 --- a/cache/local_cache.py +++ b/cache/local_cache.py @@ -6,12 +6,12 @@ import asyncio import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple -from abs_cache import Cache +from cache.abs_cache import AbstractCache -class ExpiringLocalCache(Cache): +class ExpiringLocalCache(AbstractCache): def __init__(self, cron_interval: int = 10): """ @@ -60,6 +60,21 @@ class ExpiringLocalCache(Cache): """ self._cache_container[key] = (value, time.time() + expire_time) + def keys(self, pattern: str) -> List[str]: + """ + 获取所有符合pattern的key + :param pattern: 匹配模式 + :return: + """ + if pattern == '*': + return list(self._cache_container.keys()) + + # 本地缓存通配符暂时将*替换为空 + if '*' in pattern: + pattern = pattern.replace('*', '') + + return [key for key in self._cache_container.keys() if pattern in key] + def _schedule_clear(self): """ 开启定时清理任务, @@ -93,13 +108,13 @@ class ExpiringLocalCache(Cache): await asyncio.sleep(self._cron_interval) - if __name__ == '__main__': cache = ExpiringLocalCache(cron_interval=2) cache.set('name', '程序员阿江-Relakkes', 3) print(cache.get('key')) + print(cache.keys("*")) time.sleep(4) print(cache.get('key')) del cache time.sleep(1) - print("done") \ No newline at end of file + print("done") diff --git a/cache/redis_cache.py b/cache/redis_cache.py index 3fcbc40..4c9df0b 100644 --- a/cache/redis_cache.py +++ b/cache/redis_cache.py @@ -3,18 +3,18 @@ # @Name : 程序员阿江-Relakkes # @Time : 2024/5/29 22:57 # @Desc : RedisCache实现 -import os import pickle import time -from typing import Any +from typing import Any, List -from abs_cache import Cache from redis import Redis +from cache.abs_cache import AbstractCache from config import db_config -class RedisCache(Cache): +class RedisCache(AbstractCache): + def __init__(self) -> None: # 连接redis, 返回redis客户端 self._redis_client = self._connet_redis() @@ -53,12 +53,19 @@ class RedisCache(Cache): """ self._redis_client.set(key, pickle.dumps(value), ex=expire_time) + def keys(self, pattern: str) -> List[str]: + """ + 获取所有符合pattern的key + """ + return [key.decode() for key in self._redis_client.keys(pattern)] + if __name__ == '__main__': redis_cache = RedisCache() # basic usage redis_cache.set("name", "程序员阿江-Relakkes", 1) print(redis_cache.get("name")) # Relakkes + print(redis_cache.keys("*")) # ['name'] time.sleep(2) print(redis_cache.get("name")) # None diff --git a/config/db_config.py b/config/db_config.py index 1ee1996..399220e 100644 --- a/config/db_config.py +++ b/config/db_config.py @@ -14,3 +14,7 @@ REDIS_DB_HOST = "127.0.0.1" # your redis host REDIS_DB_PWD = os.getenv("REDIS_DB_PWD", "123456") # your redis password REDIS_DB_PORT = os.getenv("REDIS_DB_PORT", 6379) # your redis port REDIS_DB_NUM = os.getenv("REDIS_DB_NUM", 0) # your redis db num + +# cache type +CACHE_TYPE_REDIS = "redis" +CACHE_TYPE_MEMORY = "memory" \ No newline at end of file diff --git a/media_platform/douyin/login.py b/media_platform/douyin/login.py index 0c72907..f0b3c23 100644 --- a/media_platform/douyin/login.py +++ b/media_platform/douyin/login.py @@ -3,7 +3,6 @@ import functools import sys from typing import Optional -import redis from playwright.async_api import BrowserContext, Page from playwright.async_api import TimeoutError as PlaywrightTimeoutError from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt, @@ -11,6 +10,7 @@ from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt, import config from base.base_crawler import AbstractLogin +from cache.cache_factory import CacheFactory from tools import utils @@ -129,13 +129,13 @@ class DouYinLogin(AbstractLogin): # 检查是否有滑动验证码 await self.check_page_display_slider(move_step=10, slider_level="easy") - redis_obj = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD) + cache_client = CacheFactory.create_cache(config.CACHE_TYPE_MEMORY) max_get_sms_code_time = 60 * 2 # 最长获取验证码的时间为2分钟 while max_get_sms_code_time > 0: utils.logger.info(f"[DouYinLogin.login_by_mobile] get douyin sms code from redis remaining time {max_get_sms_code_time}s ...") await asyncio.sleep(1) sms_code_key = f"dy_{self.login_phone}" - sms_code_value = redis_obj.get(sms_code_key) + sms_code_value = cache_client.get(sms_code_key) if not sms_code_value: max_get_sms_code_time -= 1 continue diff --git a/media_platform/kuaishou/login.py b/media_platform/kuaishou/login.py index 7833139..54a9e38 100644 --- a/media_platform/kuaishou/login.py +++ b/media_platform/kuaishou/login.py @@ -3,12 +3,10 @@ import functools import sys from typing import Optional -import redis from playwright.async_api import BrowserContext, Page from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt, wait_fixed) -import config from base.base_crawler import AbstractLogin from tools import utils diff --git a/media_platform/xhs/login.py b/media_platform/xhs/login.py index 0cff85f..07c0ba2 100644 --- a/media_platform/xhs/login.py +++ b/media_platform/xhs/login.py @@ -3,13 +3,13 @@ import functools import sys from typing import Optional -import redis from playwright.async_api import BrowserContext, Page from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt, wait_fixed) import config from base.base_crawler import AbstractLogin +from cache.cache_factory import CacheFactory from tools import utils @@ -89,14 +89,14 @@ class XiaoHongShuLogin(AbstractLogin): await send_btn_ele.click() # 点击发送验证码 sms_code_input_ele = await login_container_ele.query_selector("label.auth-code > input") submit_btn_ele = await login_container_ele.query_selector("div.input-container > button") - redis_obj = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD) + cache_client = CacheFactory.create_cache(config.CACHE_TYPE_MEMORY) max_get_sms_code_time = 60 * 2 # 最长获取验证码的时间为2分钟 no_logged_in_session = "" while max_get_sms_code_time > 0: utils.logger.info(f"[XiaoHongShuLogin.login_by_mobile] get sms code from redis remaining time {max_get_sms_code_time}s ...") await asyncio.sleep(1) sms_code_key = f"xhs_{self.login_phone}" - sms_code_value = redis_obj.get(sms_code_key) + sms_code_value = cache_client.get(sms_code_key) if not sms_code_value: max_get_sms_code_time -= 1 continue diff --git a/proxy/base_proxy.py b/proxy/base_proxy.py index 40ed75f..9294008 100644 --- a/proxy/base_proxy.py +++ b/proxy/base_proxy.py @@ -2,14 +2,14 @@ # @Author : relakkes@gmail.com # @Time : 2023/12/2 11:18 # @Desc : 爬虫 IP 获取实现 -# @Url : 现在实现了极速HTTP的接口,官网地址:https://www.jisuhttp.com/?pl=mAKphQ&plan=ZY&kd=Yang +# @Url : 快代理HTTP实现,官方文档:https://www.kuaidaili.com/?ref=ldwkjqipvz6c import json from abc import ABC, abstractmethod -from typing import Dict, List - -import redis +from typing import List import config +from cache.abs_cache import AbstractCache +from cache.cache_factory import CacheFactory from tools import utils from .types import IpInfoModel @@ -30,9 +30,9 @@ class ProxyProvider(ABC): pass -class RedisDbIpCache: +class IpCache: def __init__(self): - self.redis_client = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD) + self.cache_client: AbstractCache = CacheFactory.create_cache(cache_type=config.CACHE_TYPE_MEMORY) def set_ip(self, ip_key: str, ip_value_info: str, ex: int): """ @@ -42,7 +42,7 @@ class RedisDbIpCache: :param ex: :return: """ - self.redis_client.set(name=ip_key, value=ip_value_info, ex=ex) + self.cache_client.set(key=ip_key, value=ip_value_info, expire_time=ex) def load_all_ip(self, proxy_brand_name: str) -> List[IpInfoModel]: """ @@ -51,13 +51,13 @@ class RedisDbIpCache: :return: """ all_ip_list: List[IpInfoModel] = [] - all_ip_keys: List[bytes] = self.redis_client.keys(pattern=f"{proxy_brand_name}_*") + all_ip_keys: List[str] = self.cache_client.keys(pattern=f"{proxy_brand_name}_*") try: for ip_key in all_ip_keys: - ip_value = self.redis_client.get(ip_key) + ip_value = self.cache_client.get(ip_key) if not ip_value: continue all_ip_list.append(IpInfoModel(**json.loads(ip_value))) except Exception as e: - utils.logger.error("[RedisDbIpCache.load_all_ip] get ip err from redis db", e) + utils.logger.error("[IpCache.load_all_ip] get ip err from redis db", e) return all_ip_list diff --git a/proxy/providers/jishu_http_proxy.py b/proxy/providers/jishu_http_proxy.py index fd5b3b3..b4787ab 100644 --- a/proxy/providers/jishu_http_proxy.py +++ b/proxy/providers/jishu_http_proxy.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- # @Author : relakkes@gmail.com # @Time : 2024/4/5 09:32 -# @Desc : 极速HTTP代理提供类实现,官网地址:https://www.jisuhttp.com?pl=zG3Jna +# @Desc : 已废弃!!!!!倒闭了!!!极速HTTP 代理IP实现 import os from typing import Dict, List from urllib.parse import urlencode import httpx -from proxy import IpGetError, ProxyProvider, RedisDbIpCache +from proxy import IpCache, IpGetError, ProxyProvider from proxy.types import IpInfoModel from tools import utils @@ -31,7 +31,7 @@ class JiSuHttpProxy(ProxyProvider): "pw": "1", # 是否使用账密验证, 1:是,0:否,否表示白名单验证;默认为0 "se": "1", # 返回JSON格式时是否显示IP过期时间, 1:显示,0:不显示;默认为0 } - self.ip_cache = RedisDbIpCache() + self.ip_cache = IpCache() async def get_proxies(self, num: int) -> List[IpInfoModel]: """ diff --git a/proxy/providers/kuaidl_proxy.py b/proxy/providers/kuaidl_proxy.py index c5f46c7..1b39d43 100644 --- a/proxy/providers/kuaidl_proxy.py +++ b/proxy/providers/kuaidl_proxy.py @@ -9,7 +9,7 @@ from typing import Dict, List import httpx from pydantic import BaseModel, Field -from proxy import IpGetError, IpInfoModel, ProxyProvider, RedisDbIpCache +from proxy import IpCache, IpInfoModel, ProxyProvider from proxy.types import ProviderNameEnum from tools import utils @@ -58,7 +58,7 @@ class KuaiDaiLiProxy(ProxyProvider): self.api_base = "https://dps.kdlapi.com/" self.secret_id = kdl_secret_id self.signature = kdl_signature - self.ip_cache = RedisDbIpCache() + self.ip_cache = IpCache() self.proxy_brand_name = ProviderNameEnum.KUAI_DAILI_PROVIDER.value self.params = { "secret_id": self.secret_id, diff --git a/recv_sms.py b/recv_sms.py index 300db92..fbe5eb8 100644 --- a/recv_sms.py +++ b/recv_sms.py @@ -1,17 +1,18 @@ import re from typing import List -import redis import uvicorn from fastapi import FastAPI, HTTPException, status from pydantic import BaseModel import config +from cache.abs_cache import AbstractCache +from cache.cache_factory import CacheFactory from tools import utils app = FastAPI() -redis_client = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD) +cache_client : AbstractCache = CacheFactory.create_cache(cache_type=config.CACHE_TYPE_MEMORY) class SmsNotification(BaseModel): @@ -53,7 +54,7 @@ def receive_sms_notification(sms: SmsNotification): if sms_code: # Save the verification code in Redis and set the expiration time to 3 minutes. key = f"{sms.platform}_{sms.current_number}" - redis_client.set(key, sms_code, ex=60 * 3) + cache_client.set(key, sms_code, expire_time=60 * 3) return {"status": "ok"} diff --git a/test/test_expiring_local_cache.py b/test/test_expiring_local_cache.py new file mode 100644 index 0000000..4bb17f3 --- /dev/null +++ b/test/test_expiring_local_cache.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Author : relakkes@gmail.com +# @Name : 程序员阿江-Relakkes +# @Time : 2024/6/2 10:35 +# @Desc : + +import time +import unittest + +from cache.local_cache import ExpiringLocalCache + + +class TestExpiringLocalCache(unittest.TestCase): + + def setUp(self): + self.cache = ExpiringLocalCache(cron_interval=10) + + def test_set_and_get(self): + self.cache.set('key', 'value', 10) + self.assertEqual(self.cache.get('key'), 'value') + + def test_expired_key(self): + self.cache.set('key', 'value', 1) + time.sleep(2) # wait for the key to expire + self.assertIsNone(self.cache.get('key')) + + def test_clear(self): + # 设置两个键值对,过期时间为11秒 + self.cache.set('key', 'value', 11) + # 睡眠12秒,让cache类的定时任务执行一次 + time.sleep(12) + self.assertIsNone(self.cache.get('key')) + + def tearDown(self): + del self.cache + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_redis_cache.py b/test/test_redis_cache.py new file mode 100644 index 0000000..29b5f37 --- /dev/null +++ b/test/test_redis_cache.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# @Author : relakkes@gmail.com +# @Name : 程序员阿江-Relakkes +# @Time : 2024/6/2 19:54 +# @Desc : + +import time +import unittest + +from cache.redis_cache import RedisCache + + +class TestRedisCache(unittest.TestCase): + + def setUp(self): + self.redis_cache = RedisCache() + + def test_set_and_get(self): + self.redis_cache.set('key', 'value', 10) + self.assertEqual(self.redis_cache.get('key'), 'value') + + def test_expired_key(self): + self.redis_cache.set('key', 'value', 1) + time.sleep(2) # wait for the key to expire + self.assertIsNone(self.redis_cache.get('key')) + + def test_keys(self): + self.redis_cache.set('key1', 'value1', 10) + self.redis_cache.set('key2', 'value2', 10) + keys = self.redis_cache.keys('*') + self.assertIn('key1', keys) + self.assertIn('key2', keys) + + def tearDown(self): + # self.redis_cache._redis_client.flushdb() # 清空redis数据库 + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/slider_util.py b/tools/slider_util.py index 93bc9d2..f3a123e 100644 --- a/tools/slider_util.py +++ b/tools/slider_util.py @@ -39,7 +39,7 @@ class Slide: "q=0.8,application/signed-exchange;v=b3;q=0.9", "Accept-Encoding": "gzip, deflate, br", "Accept-Language": "zh-CN,zh;q=0.9,en-GB;q=0.8,en;q=0.7,ja;q=0.6", - "Cache-Control": "max-age=0", + "AbstractCache-Control": "max-age=0", "Connection": "keep-alive", "Host": urlparse(img).hostname, "Upgrade-Insecure-Requests": "1",