feat: cache impl done

This commit is contained in:
Relakkes 2024-06-02 19:57:13 +08:00
parent 6c4116f240
commit 4bba1447f8
16 changed files with 180 additions and 43 deletions

5
cache/__init__.py vendored
View File

@ -1,5 +0,0 @@
# -*- coding: utf-8 -*-
# @Author : relakkes@gmail.com
# @Name : 程序员阿江-Relakkes
# @Time : 2024/6/2 11:05
# @Desc :

13
cache/abs_cache.py vendored
View File

@ -5,10 +5,10 @@
# @Desc : 抽象类
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, List, Optional
class Cache(ABC):
class AbstractCache(ABC):
@abstractmethod
def get(self, key: str) -> Optional[Any]:
@ -31,3 +31,12 @@ class Cache(ABC):
:return:
"""
raise NotImplementedError
@abstractmethod
def keys(self, pattern: str) -> List[str]:
"""
获取所有符合pattern的key
:param pattern: 匹配模式
:return:
"""
raise NotImplementedError

29
cache/cache_factory.py vendored Normal file
View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# @Author : relakkes@gmail.com
# @Name : 程序员阿江-Relakkes
# @Time : 2024/6/2 11:23
# @Desc :
class CacheFactory:
"""
缓存工厂类
"""
@staticmethod
def create_cache(cache_type: str, *args, **kwargs):
"""
创建缓存对象
:param cache_type: 缓存类型
:param args: 参数
:param kwargs: 关键字参数
:return:
"""
if cache_type == 'memory':
from .local_cache import ExpiringLocalCache
return ExpiringLocalCache(*args, **kwargs)
elif cache_type == 'redis':
from .redis_cache import RedisCache
return RedisCache()
else:
raise ValueError(f'Unknown cache type: {cache_type}')

23
cache/local_cache.py vendored
View File

@ -6,12 +6,12 @@
import asyncio
import time
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from abs_cache import Cache
from cache.abs_cache import AbstractCache
class ExpiringLocalCache(Cache):
class ExpiringLocalCache(AbstractCache):
def __init__(self, cron_interval: int = 10):
"""
@ -60,6 +60,21 @@ class ExpiringLocalCache(Cache):
"""
self._cache_container[key] = (value, time.time() + expire_time)
def keys(self, pattern: str) -> List[str]:
"""
获取所有符合pattern的key
:param pattern: 匹配模式
:return:
"""
if pattern == '*':
return list(self._cache_container.keys())
# 本地缓存通配符暂时将*替换为空
if '*' in pattern:
pattern = pattern.replace('*', '')
return [key for key in self._cache_container.keys() if pattern in key]
def _schedule_clear(self):
"""
开启定时清理任务,
@ -93,11 +108,11 @@ class ExpiringLocalCache(Cache):
await asyncio.sleep(self._cron_interval)
if __name__ == '__main__':
cache = ExpiringLocalCache(cron_interval=2)
cache.set('name', '程序员阿江-Relakkes', 3)
print(cache.get('key'))
print(cache.keys("*"))
time.sleep(4)
print(cache.get('key'))
del cache

15
cache/redis_cache.py vendored
View File

@ -3,18 +3,18 @@
# @Name : 程序员阿江-Relakkes
# @Time : 2024/5/29 22:57
# @Desc : RedisCache实现
import os
import pickle
import time
from typing import Any
from typing import Any, List
from abs_cache import Cache
from redis import Redis
from cache.abs_cache import AbstractCache
from config import db_config
class RedisCache(Cache):
class RedisCache(AbstractCache):
def __init__(self) -> None:
# 连接redis, 返回redis客户端
self._redis_client = self._connet_redis()
@ -53,12 +53,19 @@ class RedisCache(Cache):
"""
self._redis_client.set(key, pickle.dumps(value), ex=expire_time)
def keys(self, pattern: str) -> List[str]:
"""
获取所有符合pattern的key
"""
return [key.decode() for key in self._redis_client.keys(pattern)]
if __name__ == '__main__':
redis_cache = RedisCache()
# basic usage
redis_cache.set("name", "程序员阿江-Relakkes", 1)
print(redis_cache.get("name")) # Relakkes
print(redis_cache.keys("*")) # ['name']
time.sleep(2)
print(redis_cache.get("name")) # None

View File

@ -14,3 +14,7 @@ REDIS_DB_HOST = "127.0.0.1" # your redis host
REDIS_DB_PWD = os.getenv("REDIS_DB_PWD", "123456") # your redis password
REDIS_DB_PORT = os.getenv("REDIS_DB_PORT", 6379) # your redis port
REDIS_DB_NUM = os.getenv("REDIS_DB_NUM", 0) # your redis db num
# cache type
CACHE_TYPE_REDIS = "redis"
CACHE_TYPE_MEMORY = "memory"

View File

@ -3,7 +3,6 @@ import functools
import sys
from typing import Optional
import redis
from playwright.async_api import BrowserContext, Page
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt,
@ -11,6 +10,7 @@ from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt,
import config
from base.base_crawler import AbstractLogin
from cache.cache_factory import CacheFactory
from tools import utils
@ -129,13 +129,13 @@ class DouYinLogin(AbstractLogin):
# 检查是否有滑动验证码
await self.check_page_display_slider(move_step=10, slider_level="easy")
redis_obj = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD)
cache_client = CacheFactory.create_cache(config.CACHE_TYPE_MEMORY)
max_get_sms_code_time = 60 * 2 # 最长获取验证码的时间为2分钟
while max_get_sms_code_time > 0:
utils.logger.info(f"[DouYinLogin.login_by_mobile] get douyin sms code from redis remaining time {max_get_sms_code_time}s ...")
await asyncio.sleep(1)
sms_code_key = f"dy_{self.login_phone}"
sms_code_value = redis_obj.get(sms_code_key)
sms_code_value = cache_client.get(sms_code_key)
if not sms_code_value:
max_get_sms_code_time -= 1
continue

View File

@ -3,12 +3,10 @@ import functools
import sys
from typing import Optional
import redis
from playwright.async_api import BrowserContext, Page
from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt,
wait_fixed)
import config
from base.base_crawler import AbstractLogin
from tools import utils

View File

@ -3,13 +3,13 @@ import functools
import sys
from typing import Optional
import redis
from playwright.async_api import BrowserContext, Page
from tenacity import (RetryError, retry, retry_if_result, stop_after_attempt,
wait_fixed)
import config
from base.base_crawler import AbstractLogin
from cache.cache_factory import CacheFactory
from tools import utils
@ -89,14 +89,14 @@ class XiaoHongShuLogin(AbstractLogin):
await send_btn_ele.click() # 点击发送验证码
sms_code_input_ele = await login_container_ele.query_selector("label.auth-code > input")
submit_btn_ele = await login_container_ele.query_selector("div.input-container > button")
redis_obj = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD)
cache_client = CacheFactory.create_cache(config.CACHE_TYPE_MEMORY)
max_get_sms_code_time = 60 * 2 # 最长获取验证码的时间为2分钟
no_logged_in_session = ""
while max_get_sms_code_time > 0:
utils.logger.info(f"[XiaoHongShuLogin.login_by_mobile] get sms code from redis remaining time {max_get_sms_code_time}s ...")
await asyncio.sleep(1)
sms_code_key = f"xhs_{self.login_phone}"
sms_code_value = redis_obj.get(sms_code_key)
sms_code_value = cache_client.get(sms_code_key)
if not sms_code_value:
max_get_sms_code_time -= 1
continue

View File

@ -2,14 +2,14 @@
# @Author : relakkes@gmail.com
# @Time : 2023/12/2 11:18
# @Desc : 爬虫 IP 获取实现
# @Url : 现在实现了极速HTTP的接口官网地址https://www.jisuhttp.com/?pl=mAKphQ&plan=ZY&kd=Yang
# @Url : 快代理HTTP实现官方文档https://www.kuaidaili.com/?ref=ldwkjqipvz6c
import json
from abc import ABC, abstractmethod
from typing import Dict, List
import redis
from typing import List
import config
from cache.abs_cache import AbstractCache
from cache.cache_factory import CacheFactory
from tools import utils
from .types import IpInfoModel
@ -30,9 +30,9 @@ class ProxyProvider(ABC):
pass
class RedisDbIpCache:
class IpCache:
def __init__(self):
self.redis_client = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD)
self.cache_client: AbstractCache = CacheFactory.create_cache(cache_type=config.CACHE_TYPE_MEMORY)
def set_ip(self, ip_key: str, ip_value_info: str, ex: int):
"""
@ -42,7 +42,7 @@ class RedisDbIpCache:
:param ex:
:return:
"""
self.redis_client.set(name=ip_key, value=ip_value_info, ex=ex)
self.cache_client.set(key=ip_key, value=ip_value_info, expire_time=ex)
def load_all_ip(self, proxy_brand_name: str) -> List[IpInfoModel]:
"""
@ -51,13 +51,13 @@ class RedisDbIpCache:
:return:
"""
all_ip_list: List[IpInfoModel] = []
all_ip_keys: List[bytes] = self.redis_client.keys(pattern=f"{proxy_brand_name}_*")
all_ip_keys: List[str] = self.cache_client.keys(pattern=f"{proxy_brand_name}_*")
try:
for ip_key in all_ip_keys:
ip_value = self.redis_client.get(ip_key)
ip_value = self.cache_client.get(ip_key)
if not ip_value:
continue
all_ip_list.append(IpInfoModel(**json.loads(ip_value)))
except Exception as e:
utils.logger.error("[RedisDbIpCache.load_all_ip] get ip err from redis db", e)
utils.logger.error("[IpCache.load_all_ip] get ip err from redis db", e)
return all_ip_list

View File

@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
# @Author : relakkes@gmail.com
# @Time : 2024/4/5 09:32
# @Desc : 极速HTTP代理提供类实现,官网地址https://www.jisuhttp.com?pl=zG3Jna
# @Desc : 已废弃倒闭了极速HTTP 代理IP实现
import os
from typing import Dict, List
from urllib.parse import urlencode
import httpx
from proxy import IpGetError, ProxyProvider, RedisDbIpCache
from proxy import IpCache, IpGetError, ProxyProvider
from proxy.types import IpInfoModel
from tools import utils
@ -31,7 +31,7 @@ class JiSuHttpProxy(ProxyProvider):
"pw": "1", # 是否使用账密验证, 10否表示白名单验证默认为0
"se": "1", # 返回JSON格式时是否显示IP过期时间 1显示0不显示默认为0
}
self.ip_cache = RedisDbIpCache()
self.ip_cache = IpCache()
async def get_proxies(self, num: int) -> List[IpInfoModel]:
"""

View File

@ -9,7 +9,7 @@ from typing import Dict, List
import httpx
from pydantic import BaseModel, Field
from proxy import IpGetError, IpInfoModel, ProxyProvider, RedisDbIpCache
from proxy import IpCache, IpInfoModel, ProxyProvider
from proxy.types import ProviderNameEnum
from tools import utils
@ -58,7 +58,7 @@ class KuaiDaiLiProxy(ProxyProvider):
self.api_base = "https://dps.kdlapi.com/"
self.secret_id = kdl_secret_id
self.signature = kdl_signature
self.ip_cache = RedisDbIpCache()
self.ip_cache = IpCache()
self.proxy_brand_name = ProviderNameEnum.KUAI_DAILI_PROVIDER.value
self.params = {
"secret_id": self.secret_id,

View File

@ -1,17 +1,18 @@
import re
from typing import List
import redis
import uvicorn
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
import config
from cache.abs_cache import AbstractCache
from cache.cache_factory import CacheFactory
from tools import utils
app = FastAPI()
redis_client = redis.Redis(host=config.REDIS_DB_HOST, password=config.REDIS_DB_PWD)
cache_client : AbstractCache = CacheFactory.create_cache(cache_type=config.CACHE_TYPE_MEMORY)
class SmsNotification(BaseModel):
@ -53,7 +54,7 @@ def receive_sms_notification(sms: SmsNotification):
if sms_code:
# Save the verification code in Redis and set the expiration time to 3 minutes.
key = f"{sms.platform}_{sms.current_number}"
redis_client.set(key, sms_code, ex=60 * 3)
cache_client.set(key, sms_code, expire_time=60 * 3)
return {"status": "ok"}

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# @Author : relakkes@gmail.com
# @Name : 程序员阿江-Relakkes
# @Time : 2024/6/2 10:35
# @Desc :
import time
import unittest
from cache.local_cache import ExpiringLocalCache
class TestExpiringLocalCache(unittest.TestCase):
def setUp(self):
self.cache = ExpiringLocalCache(cron_interval=10)
def test_set_and_get(self):
self.cache.set('key', 'value', 10)
self.assertEqual(self.cache.get('key'), 'value')
def test_expired_key(self):
self.cache.set('key', 'value', 1)
time.sleep(2) # wait for the key to expire
self.assertIsNone(self.cache.get('key'))
def test_clear(self):
# 设置两个键值对过期时间为11秒
self.cache.set('key', 'value', 11)
# 睡眠12秒让cache类的定时任务执行一次
time.sleep(12)
self.assertIsNone(self.cache.get('key'))
def tearDown(self):
del self.cache
if __name__ == '__main__':
unittest.main()

40
test/test_redis_cache.py Normal file
View File

@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# @Author : relakkes@gmail.com
# @Name : 程序员阿江-Relakkes
# @Time : 2024/6/2 19:54
# @Desc :
import time
import unittest
from cache.redis_cache import RedisCache
class TestRedisCache(unittest.TestCase):
def setUp(self):
self.redis_cache = RedisCache()
def test_set_and_get(self):
self.redis_cache.set('key', 'value', 10)
self.assertEqual(self.redis_cache.get('key'), 'value')
def test_expired_key(self):
self.redis_cache.set('key', 'value', 1)
time.sleep(2) # wait for the key to expire
self.assertIsNone(self.redis_cache.get('key'))
def test_keys(self):
self.redis_cache.set('key1', 'value1', 10)
self.redis_cache.set('key2', 'value2', 10)
keys = self.redis_cache.keys('*')
self.assertIn('key1', keys)
self.assertIn('key2', keys)
def tearDown(self):
# self.redis_cache._redis_client.flushdb() # 清空redis数据库
pass
if __name__ == '__main__':
unittest.main()

View File

@ -39,7 +39,7 @@ class Slide:
"q=0.8,application/signed-exchange;v=b3;q=0.9",
"Accept-Encoding": "gzip, deflate, br",
"Accept-Language": "zh-CN,zh;q=0.9,en-GB;q=0.8,en;q=0.7,ja;q=0.6",
"Cache-Control": "max-age=0",
"AbstractCache-Control": "max-age=0",
"Connection": "keep-alive",
"Host": urlparse(img).hostname,
"Upgrade-Insecure-Requests": "1",