def insert_all_entry(self, entries):
    """Persist a batch of mapped objects in one transaction.

    The target table is created first when it does not exist yet. Errors are
    reported on stdout and swallowed (best-effort write), and the session is
    always closed afterwards.

    :param entries: list of mapped instances that share one table
    :return: None
    """
    try:
        if not entries:
            # create_table() reads entries[0]; an empty batch has nothing
            # to create or insert, so skip the database round trip.
            return
        self.create_table(entries)
        self.session.add_all(entries)
        self.session.commit()
    except Exception as e:
        # Keep the original one-line message and add the traceback so a
        # failed batch can actually be diagnosed.
        print(e)
        print(traceback.format_exc())
    finally:
        self.session.close()
def insert_or_update(self, entry, query_conditions):
    """Update the rows matching ``query_conditions`` or insert ``entry``.

    To be updatable, ``entry`` must expose a ``to_dict()`` method returning
    the ``{column: value}`` pairs that should be written.

    Values are sent as bound parameters instead of being pasted into the SQL
    text, so quotes inside a value can neither break nor alter the statement.
    Each value is bound as ``str(value)``, which keeps the data written
    identical to the previous quoted-literal form (e.g. ``None`` is still
    stored as the text ``'None'``).

    :param entry: mapped instance providing ``__tablename__``
    :param query_conditions: dict of ``{column: value}`` identifying the row(s)
    :return: None; errors are printed and swallowed
    """
    try:
        self.create_table(entry)
        if not query_conditions:
            raise ValueError("query_conditions must contain at least one column")
        table_name = entry.__tablename__
        # Column and table names still come from calling code and are
        # interpolated; only the values are untrusted data.
        where_sql = " AND ".join(
            f"{column} = :w_{pos}" for pos, column in enumerate(query_conditions)
        )
        where_params = {
            f"w_{pos}": str(value)
            for pos, value in enumerate(query_conditions.values())
        }
        select_sql = text(f"select count(1) from {table_name} where 1=1 and {where_sql}")
        result = self.session.execute(select_sql, where_params)
        # A match means the row already exists, so update it in place.
        if result.scalar() > 0:
            if not hasattr(entry, 'to_dict'):
                raise Exception("对象 不包含 to_dict 方法,请添加!")
            new_values = entry.to_dict()
            # Separate prefixes (s_/w_) keep SET and WHERE parameters apart
            # when the same column appears in both.
            set_sql = ", ".join(
                f"{column} = :s_{pos}" for pos, column in enumerate(new_values)
            )
            set_params = {
                f"s_{pos}": str(value)
                for pos, value in enumerate(new_values.values())
            }
            update_sql = text(f"UPDATE {table_name} SET {set_sql} WHERE {where_sql}")
            self.session.execute(update_sql, {**set_params, **where_params})
        else:
            # No matching row: fall back to a plain insert.
            self.insert_entry(entry)
        self.session.commit()
    except Exception as e:
        trace = traceback.extract_tb(e.__traceback__)
        for filename, lineno, funcname, source in trace:
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}"
                  f"错误内容:{traceback.format_exc()}")
    finally:
        self.session.expire_all()
        self.session.close()
def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
    """Load the rows of ``model`` into a pandas DataFrame.

    :param model: mapped class to query
    :param order_col: optional column expression to order by
    :param page_number: 1-based page index; used only together with ``page_size``
    :param page_size: number of rows per page
    :return: DataFrame with the rows, an empty DataFrame when the table does
        not exist, or None when the query failed (the error is printed)
    """
    try:
        # Only query when the table exists.
        if self.has_table(model):
            query = self.session.query(model)
            if order_col is not None:
                query = query.order_by(order_col)
            if page_number is not None and page_size is not None:
                offset = (page_number - 1) * page_size
                query = query.offset(offset).limit(page_size)
            # Close the connection deterministically instead of leaving it
            # checked out until garbage collection.
            with self.engine.connect() as conn:
                return pandas.read_sql_query(query.statement, conn)
        return pandas.DataFrame()
    except Exception as e:
        trace = traceback.extract_tb(e.__traceback__)
        for filename, lineno, funcname, source in trace:
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}"
                  f"错误内容:{traceback.format_exc()}")
    finally:
        self.session.close()
        self.engine.dispose()
def pandas_query_by_condition(self, model, query_condition):
    """Query ``model`` with a filter expression and return a DataFrame.

    For several conditions combine them first, e.g.
    ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.

    :param model: mapped class to query; results are ordered by ``model.id``
    :param query_condition: SQLAlchemy filter expression
    :return: DataFrame with a fresh positional index (the previous index is
        kept as an ``index`` column), or None when the query failed
    """
    try:
        query = self.session.query(model).filter(query_condition).order_by(model.id)
        frame = self.pandas_query_by_sql(stmt=query.statement)
        if frame is None:
            # pandas_query_by_sql already reported its own failure; calling
            # reset_index() on None would only add a misleading traceback.
            return None
        return frame.reset_index()
    except Exception as e:
        trace = traceback.extract_tb(e.__traceback__)
        for filename, lineno, funcname, source in trace:
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}"
                  f"错误内容:{traceback.format_exc()}")
    finally:
        self.session.close()
def has_table(self, entries):
    """Report whether the table behind ``entries`` exists in the database.

    :param entries: a mapped class/instance, or a list of mapped instances
        (the first element decides the table)
    :return: True when the table exists; False otherwise, including for an
        empty list
    """
    if isinstance(entries, list):
        if not entries:
            # No element to read __tablename__ from.
            return False
        table_name = entries[0].__tablename__
    else:
        table_name = entries.__tablename__
    # SQLAlchemy 2.x Inspectors cache reflection results, so a table created
    # after the first check would still be reported as missing. Drop the
    # cache when the inspector offers that hook (older versions do not).
    clear_cache = getattr(self.inspector, "clear_cache", None)
    if callable(clear_cache):
        clear_cache()
    return self.inspector.has_table(table_name)
class StockBasic(declarative_base()):
    """ORM mapping for the ``stock_basic`` table (one row per listed security)."""

    __tablename__ = 'stock_basic'

    id = Column(Integer, primary_key=True, autoincrement=True)
    ts_code = Column(String, nullable=False, unique=True)
    symbol = Column(String, nullable=False)
    name = Column(String)
    comp_name = Column(String)
    comp_name_en = Column(String)
    isin_code = Column(String)
    exchange = Column(String)
    list_board = Column(String)
    list_date = Column(String)
    delist_date = Column(String)
    crncy_code = Column(String)
    pinyin = Column(String)
    list_board_name = Column(String)
    is_shsc = Column(String)
    comp_code = Column(String)

    def __init__(self, **kwargs):
        """Populate the instance from keyword arguments (e.g. a DataFrame row).

        Attributes are assigned with ``setattr`` so mapped columns go through
        the ORM's attribute instrumentation instead of being written straight
        into ``__dict__``. Keys that are not columns become plain attributes,
        as before.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        # The previous implementation returned an empty string, which made
        # printed query results unreadable.
        return (f"<StockBasic(id={self.id!r}, ts_code={self.ts_code!r}, "
                f"name={self.name!r})>")

    def to_dict(self):
        """Return the columns that ``insert_or_update`` is allowed to overwrite."""
        # Only these attributes are refreshed on an existing row.
        allowed_attributes = ['delist_date', 'is_shsc', 'comp_name', 'comp_name_en']
        return {attr: getattr(self, attr) for attr in allowed_attributes}
def get_stock_daily(table_name):
    """Build a daily-quote ORM class bound to ``table_name``.

    Every call creates a new declarative base and a new mapped class, so the
    same column layout can be stored under many table names.

    :param table_name: name of the table the returned class maps to
    :return: the mapped ``StockDaily`` class (not an instance)
    """

    class StockDaily(declarative_base()):
        __tablename__ = table_name

        id = Column(Integer, primary_key=True, autoincrement=True)
        ts_code = Column(String, nullable=False)  # TS security code
        trade_date = Column(String, nullable=False)  # trade date
        crncy_code = Column(String)  # currency code
        pre_close = Column(Float)  # previous close (CNY)
        open = Column(Float)  # open (CNY)
        high = Column(Float)  # high (CNY)
        low = Column(Float)  # low (CNY)
        close = Column(Float)  # close (CNY)
        change = Column(Float)  # price change (CNY)
        pct_chg = Column(Float)  # price change (%)
        volume = Column(Float)  # volume (lots)
        amount = Column(Float)  # turnover (thousand CNY)
        adj_pre_close = Column(Float)  # adjusted previous close (CNY)
        adj_open = Column(Float)  # adjusted open (CNY)
        adj_high = Column(Float)  # adjusted high (CNY)
        adj_low = Column(Float)  # adjusted low (CNY)
        adj_close = Column(Float)  # adjusted close (CNY)
        adj_factor = Column(Float)  # adjustment factor
        avg_price = Column(Float)  # average price (VWAP)
        trade_status = Column(String)  # trading status
        turnover_rate = Column(Float)  # turnover rate
        turnover_rate_f = Column(Float)  # turnover rate (free-float shares)
        volume_ratio = Column(Float)  # volume ratio
        pe = Column(Float)  # P/E (market cap / net profit)
        pe_ttm = Column(Float)  # P/E (TTM)
        pb = Column(Float)  # P/B (market cap / net assets)
        ps = Column(Float)  # P/S
        ps_ttm = Column(Float)  # P/S (TTM)
        dv_ratio = Column(Float)  # dividend yield
        dv_ttm = Column(Float)  # trailing dividend yield
        total_share = Column(Float)  # total shares
        float_share = Column(Float)  # floating shares
        free_share = Column(Float)  # free-float shares
        total_mv = Column(Float)  # total market value
        circ_mv = Column(Float)  # circulating market value

        def __init__(self, **kwargs):
            # setattr routes mapped columns through the ORM instrumentation
            # instead of writing into __dict__ behind its back.
            for key, value in kwargs.items():
                setattr(self, key, value)

        def __repr__(self):
            # Previously returned an empty string.
            return (f"<StockDaily(table={self.__tablename__!r}, "
                    f"ts_code={self.ts_code!r}, trade_date={self.trade_date!r})>")

        def to_dict(self):
            """Return the column values written by ``insert_or_update``."""
            return {"ts_code": self.ts_code, "trade_date": self.trade_date,
                    "pre_close": self.pre_close,
                    "turnover_rate": self.turnover_rate}

    return StockDaily
class MysqlDbMain(DbMain):
    """``DbMain`` bound to the MySQL server described in ``db_config.mysql``."""

    def __init__(self):
        super().__init__()
        self.engine = create_mysql_engine()
        self.session = self.get_session()
        # DbMain.create_table()/has_table() call self.inspector.has_table();
        # without an inspector every insert failed on None. This mirrors
        # SqliteDbMain. NOTE(review): building the inspector may open a
        # connection, so construction can now fail when MySQL is
        # unreachable — confirm that is acceptable.
        self.inspector = inspect(self.engine)
def get_db_size(self):
    """Print and return the size of the database file.

    The file is located through ``self.engine_path``.

    :return: the size formatted as ``"<value> MB"`` with two decimals
    """
    size_in_bytes = Path(self.engine_path).stat().st_size
    size_in_mb = size_in_bytes / (1024 * 1024)
    total = f"{size_in_mb:.2f} MB"
    print(f"文件大小: {size_in_mb:.2f} MB")
    return total
def init_data(self):
    """Full load: fetch about three years of 30-minute bars for every stock.

    For each stock in ``self.code_res`` that passes the security-type check
    and ``__filter_stock``, the bars are downloaded (up to 7 attempts),
    a per-stock table ``<code>_daily_freq_30`` is created and all rows are
    inserted in one batch.

    :return: None
    """
    max_attempts = 7
    for result in self.code_res:
        tdx_util = TdxUtil("")
        s_type = tdx_util.get_security_type(code=result.ts_code)
        if not (s_type in tdx_util.SECURITY_TYPE
                and self.__filter_stock(result.ts_code, result.name)):
            continue
        # Query window: the last three years up to today's close.
        now = datetime.now()
        start_time = f"{(now - timedelta(days=365 * 3)).strftime('%Y-%m-%d')} 09:30:00"
        end_time = f"{now.strftime('%Y-%m-%d')} 15:00:00"
        # The old log line used an undefined name (trade_date); log the
        # window that is actually requested.
        print(f"{result.id},{result.ts_code} --> start_time={start_time} --->> end_time={end_time}")
        df = None
        for attempt in range(max_attempts):
            try:
                df = xcsc_pro.stk_mins(ts_code=result.ts_code, freq='30min',
                                       start_time=start_time,
                                       end_time=end_time)[::-1]
                # Stop retrying on success; the previous loop had no exit
                # here and re-downloaded the same data forever.
                break
            except Exception as e:
                print(f"{result.ts_code} stk_mins request failed: {e!r}")
                if attempt < max_attempts - 1:
                    print("请求超时,等待2分钟后重试...")
                    time.sleep(120)  # wait two minutes before the next attempt
        if df is None:
            # Every attempt failed; skip this stock rather than crash on None.
            print(f"{result.ts_code},{result.name} stk_mins download failed after "
                  f"{max_attempts} attempts, skipped")
            continue
        print(f"{result.ts_code},{result.name} 获取 stk_mins _ 30 数据完毕!")
        # One table per stock.
        table_name = str(result.ts_code).split(".")[0] + "_daily_freq_30"
        new_table_class = get_stock_daily_freq(table_name=table_name)
        self.db_main.create_table(new_table_class)
        entries = []
        for index, row in df.iterrows():
            min_time = row['trade_time']
            # Keep only the date part of the bar timestamp.
            trade_time = datetime.strptime(min_time, '%Y-%m-%d %H:%M:%S').strftime('%Y-%m-%d')
            entry = new_table_class(
                trade_date=trade_time,
                time=min_time,
                open=row['open'],
                close=row['close'],
                high=row['high'],
                low=row['low'],
                vol=row['vol'],
                amount=row['amount'],
            )
            print(entry.to_dict())
            entries.append(entry)
        self.db_main.insert_all_entry(entries)
def __init__(self):
    """Open the stock-info SQLite database and keep a session for queries."""
    # Helper bound to the database file named by config.stock_info_db.
    self.db_main = SqliteDbMain(config.stock_info_db)
    # A session obtained from get_session(), separate from db_main.session;
    # get_stock_basic() queries through it and closes it afterwards.
    self.session = self.db_main.get_session()
def get_stock_basic(self, ts_code=None, symbol=None, restart_id=0):
    """Return the ``StockBasic`` rows that are not delisted.

    At most one selector is applied, in this order of precedence:
    ``ts_code`` (str or list), then ``symbol`` (str or list, only when
    ``ts_code`` is None), then ``restart_id`` (rows with ``id >= restart_id``).

    :param ts_code: one TS code or a list of TS codes
    :param symbol: one symbol or a list of symbols
    :param restart_id: lowest row id to return when no code/symbol selector applies
    :return: list of ``StockBasic`` instances
    """
    # Rows still listed carry the literal text 'None' in delist_date.
    # NOTE(review): presumably written by insert_or_update's string
    # formatting — confirm against the stored data. ``==`` is used instead
    # of ``IS`` because ``IS '<string>'`` is not portable SQL.
    not_delisted = StockBasic.delist_date == 'None'
    try:
        selector = None
        if ts_code is not None:
            if isinstance(ts_code, str):
                selector = StockBasic.ts_code == ts_code
            elif isinstance(ts_code, list):
                selector = StockBasic.ts_code.in_(ts_code)
        elif symbol is not None:
            if isinstance(symbol, str):
                selector = StockBasic.symbol == symbol
            elif isinstance(symbol, list):
                # Filter on the symbol column; the previous code compared
                # the symbol list against ts_code and matched nothing.
                selector = StockBasic.symbol.in_(symbol)
        if selector is None and restart_id > 0:
            selector = StockBasic.id >= restart_id
        if selector is None:
            sql_cond = and_(not_delisted)
        else:
            sql_cond = and_(selector, not_delisted)
        return self.session.query(StockBasic).filter(sql_cond).all()
    finally:
        self.session.close()
def news_cctv(self):
    """Fetch the CCTV news items for ``self.trade_date``.

    The download is attempted up to five times with a one-second pause
    after each failure.

    :return: the DataFrame returned by ``ak.news_cctv``, or None when all
        five attempts failed
    """
    for _ in range(5):
        try:
            return ak.news_cctv(date=self.trade_date)
        except Exception as e:
            # Show the exception itself; e.__traceback__ only rendered as
            # "<traceback object at 0x...>", which says nothing about the error.
            print(f"{self.trade_date} 日的新闻咨询拉取发生错误!{e!r}")
            time.sleep(1)
    print("中央新闻 5次出现错误,请关注!!!")
    return None
def html_page_data(self):
    """Build the page snippets for the CCTV and CLS news feeds.

    :return: tuple ``(cctv_content, cls_content)``. Each is the concatenation
        of one fragment per news item; a feed that could not be fetched
        contributes an empty string
    """
    cctv_df = self.news_cctv()
    cls_df = self.news_cls()
    # Both fetchers return None after five failed attempts; treat that as
    # "no items" instead of crashing on None.iterrows().
    cctv_content = '' if cctv_df is None else ''.join(
        f'{title}开发中...' for title in cctv_df["title"]
    )
    # CLS items with a title of one character or less are skipped.
    cls_content = '' if cls_df is None else ''.join(
        f'{title}开发中...' for title in cls_df["标题"] if len(title) > 1
    )
    return cctv_content, cls_content
def create_file(file_name):
    """Create the directory ``file_name`` (with its parents) if nothing exists there.

    Despite the name, a directory is created, not a file. A path that already
    exists is left untouched.

    :param file_name: path of the directory to create
    :return: None
    """
    target = Path(file_name)
    if target.exists():
        return
    # Build the directory together with any missing parents.
    target.mkdir(parents=True)
def return_trading_day(now_date, offset_date, format_type='%Y%m%d'):
    """Return the trading day ``offset_date`` sessions away from ``now_date``.

    :param now_date: reference trading day, a ``datetime`` or a
        ``'YYYY-MM-DD'`` string
    :param offset_date: number of trading days to move; negative for earlier
        days, positive for later ones
    :param format_type: ``strftime`` format of the returned date
    :return: the target trading day formatted with ``format_type``
    :raises IndexError: when ``now_date`` is not in the trading calendar
    :raises KeyError: when the offset leaves the calendar range
    """
    df = ak.tool_trade_date_hist_sina()
    # Normalise the calendar column to datetime for the comparison below.
    df['trade_date'] = pd.to_datetime(df['trade_date'])
    if isinstance(now_date, datetime):
        formatted_date = now_date.strftime('%Y-%m-%d')
    else:
        formatted_date = now_date
    selected_indexes = df[df['trade_date'] == formatted_date].index
    if selected_indexes.empty:
        # Same exception type as before, but with a message that names the
        # actual problem instead of "index 0 is out of bounds".
        raise IndexError(f"{formatted_date} is not a trading day in the calendar")
    target = selected_indexes[0] + offset_date
    if target not in df.index:
        raise KeyError(
            f"offset {offset_date} from {formatted_date} is outside the trading calendar"
        )
    return df.loc[target, 'trade_date'].strftime(format_type)
def buying_and_selling_decisions(df, cond_buy: str, cond_sell: str, cond_buy_read=None, cond_sell_read=None):
    """
    Replay buy/sell rules over ``df`` row by row and record every closed trade.

    Every condition is a Python expression string that is evaluated once per
    row inside this function's scope, so it may reference the loop locals
    ``row``, ``index``, ``trade_date`` and ``c`` (the row's ``close``).
    A position is opened when the "ready to buy" condition has armed the buy
    flag and ``cond_buy`` is true; it is closed when the "ready to sell"
    condition has armed the sell flag and ``cond_sell`` is true.

    :param df: frame iterated with ``iterrows()``; each row must provide
        ``trade_date`` (a ``%Y%m%d`` string), ``close`` and ``ts_code``
    :param cond_buy: expression that triggers the purchase
    :param cond_sell: expression that triggers the sale
    :param cond_buy_read: optional expression that arms the buy flag;
        ``None`` arms it on every row while no position is held
    :param cond_sell_read: optional expression that arms the sell flag;
        ``None`` arms it on every row while a position is held
    :return: DataFrame with one row per closed trade and the columns
        ``code``, ``买入日期``, ``卖出日期``, ``盈利`` (percent) and ``持仓天数``
    """
    # SECURITY: the condition strings are executed with eval(); pass trusted
    # expressions only. eval() must stay inline here because the expressions
    # rely on this function's local names.
    total, buy, sell, buy_date = 0, 0, 0, 0
    flag_ready_b, flag_ready_s = False, False
    res_df = pd.DataFrame(
        columns=['code', '买入日期', '卖出日期', '盈利', '持仓天数'])
    for index, row in df.iterrows():
        trade_date, c = row["trade_date"], row['close']
        # The None check has to come first: eval(None) raises TypeError, which
        # made the documented defaults unusable.
        if (cond_buy_read is None or eval(cond_buy_read)) and buy == 0:
            # Armed: waiting for the buy trigger.
            print(f"{trade_date} {cond_buy_read},准备买入!")
            flag_ready_b = True
        if eval(cond_buy) and buy == 0 and flag_ready_b:
            # Open the position at the current close.
            print(f"{trade_date} {cond_buy},买入操作!")
            buy, buy_date = c, trade_date
            flag_ready_b = False
        if (cond_sell_read is None or eval(cond_sell_read)) and buy > 0:
            # Armed: waiting for the sell trigger.
            print(f"{trade_date} {cond_sell_read},准备卖出!")
            flag_ready_s = True
        if eval(cond_sell) and buy > 0 and flag_ready_s:
            print(f"{trade_date} {cond_sell},卖出操作!")
            total += 1
            sell = c
            # Profit in percent of the purchase price.
            chg = round(((sell - buy) / buy) * 100, 2)
            # Holding period in calendar days between purchase and sale.
            start_date = datetime.strptime(buy_date, "%Y%m%d")
            end_date = datetime.strptime(trade_date, "%Y%m%d")
            interval_days = (end_date - start_date).days
            # Reset the position so the next purchase can be armed.
            buy, flag_ready_s = 0, False
            res_df.loc[total] = [row['ts_code'], buy_date, trade_date, chg, interval_days]
    return res_df
def REF(DF, N):
    """
    Return ``DF`` lagged by ``N`` periods (the value ``N`` rows earlier).

    Implemented with ``shift`` instead of ``DF - DF.diff(N)``: the arithmetic
    round trip yields NaN next to infinite values (``inf - inf``) and adds
    floating-point rounding error, while ``shift`` returns the stored values
    untouched.

    :param DF: pandas Series (or DataFrame) to lag
    :param N: number of rows to look back
    :return: object of the same index as ``DF``; the first ``N`` rows are NaN
    """
    return DF.shift(N)
def PBX(close, n1, n2, n3, n4, n5, n6):
    """
    Waterfall lines (PBX).

    Each line is the average of three smoothers of ``close`` built from one
    base period ``n``: ``EMA(close, n)``, ``MA(close, 2 * n)`` and
    ``MA(close, 4 * n)``, rounded to two decimals.

    :param close: closing-price series
    :param n1: base period of line ``PBX1``
    :param n2: base period of line ``PBX2``
    :param n3: base period of line ``PBX3``
    :param n4: base period of line ``PBX4``
    :param n5: base period of line ``PBX5``
    :param n6: base period of line ``PBX6``
    :return: DataFrame with the columns ``PBX1`` .. ``PBX6``
    """

    def _waterfall(period):
        # Blend the fast EMA with the 2x and 4x simple moving averages.
        blended = (EMA(close, period) + MA(close, 2 * period) + MA(close, 4 * period)) / 3
        return round(blended, 2)

    periods = (n1, n2, n3, n4, n5, n6)
    return pd.DataFrame(
        {f'PBX{rank}': _waterfall(period) for rank, period in enumerate(periods, start=1)}
    )
def RSI(c, N1, N2, N3):
    """
    Relative Strength Index for three periods.

    Formula: ``RSI = SMA(MAX(CLOSE - LC, 0), N, 1) / SMA(ABS(CLOSE - LC), N, 1) * 100``
    where ``LC`` is the previous close.

    The earlier implementation rounded the denominator to three decimals
    (after scaling it by 100) and compensated with a factor of 10000. That
    injected rounding error into every value and could turn a very small
    denominator into exactly 0. The ratio is now computed directly and only
    the final result is rounded.

    :param c: closing-price series
    :param N1: smoothing period of ``RSI1``
    :param N2: smoothing period of ``RSI2``
    :param N3: smoothing period of ``RSI3``
    :return: DataFrame with the columns ``RSI1``, ``RSI2`` and ``RSI3``,
        each rounded to two decimals
    """
    DIF = c - REF(c, 1)
    # Both inputs are shared by the three periods, so build them once.
    gains = MAX(DIF, 0)
    moves = ABS(DIF)

    def _rsi(period):
        return round(SMA(gains, period) / SMA(moves, period) * 100, 2)

    return pd.DataFrame({'RSI1': _rsi(N1), 'RSI2': _rsi(N2), 'RSI3': _rsi(N3)})
ZIG_STATE_START = 0
ZIG_STATE_RISE = 1
ZIG_STATE_FALL = 2


def ZIG(d, k, n):
    """
    Zig-zag indicator: connect turning points that reverse by more than ``n`` percent.

    The scan walks through ``k`` once. While no trend is known it waits for a
    move of ``n`` percent away from the first bar; afterwards it tracks the
    running extreme (the "candidate") and confirms it as a turning point once
    price retraces ``n`` percent from it. The first and the last bar are
    always turning points. Values between two turning points are linearly
    interpolated.

    :param d: trade dates, indexable like ``k``; used only for the trace output
    :param k: prices, indexable with ``0 .. len(k) - 1``
    :param n: reversal threshold in percent (fractional values are honoured)
    :return: tuple ``(z, peers)`` where ``z`` is a Series with the zig-zag
        line and ``peers`` the list of turning-point positions
    """
    size = len(k)
    z = np.zeros(size)
    # With fewer than two bars the scan below would never reach its tail check
    # and would index past the end, so answer those inputs directly.
    if size == 0:
        return pd.Series(z), []
    if size == 1:
        z[0] = k[0]
        return pd.Series(z), [0]

    # Keep the exact ratio: rounding it to two decimals turned e.g. 0.4% into 0.
    x = n / 100
    peer_i = 0
    candidate_i = None
    scan_i = 0
    peers = [0]
    state = ZIG_STATE_START
    while True:
        scan_i += 1
        if scan_i == size - 1:
            # Last bar: it always closes the line. A pending candidate is kept
            # as a turning point only when the last bar does not extend it.
            if candidate_i is None:
                peers.append(scan_i)
            elif state == ZIG_STATE_RISE:
                if k[scan_i] >= k[candidate_i]:
                    print(d[scan_i], "1 --->>>", d[candidate_i])
                else:
                    peers.append(candidate_i)
                peers.append(scan_i)
            elif state == ZIG_STATE_FALL:
                if k[scan_i] <= k[candidate_i]:
                    print(d[scan_i], "2 --->>>", d[candidate_i])
                else:
                    peers.append(candidate_i)
                peers.append(scan_i)
            break
        if state == ZIG_STATE_START:
            # No trend yet: wait for a move of x away from the first bar.
            if k[scan_i] >= k[peer_i] * (1 + x):
                print(d[scan_i], "3 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_RISE
            elif k[scan_i] <= k[peer_i] * (1 - x):
                print(d[scan_i], "4 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_FALL
        elif state == ZIG_STATE_RISE:
            if k[scan_i] >= k[candidate_i]:
                # New high: move the candidate peak forward.
                candidate_i = scan_i
            elif k[scan_i] <= k[candidate_i] * (1 - x):
                # Retraced far enough: the candidate peak is confirmed.
                print(d[scan_i], "5 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_FALL
                candidate_i = scan_i
        elif state == ZIG_STATE_FALL:
            if k[scan_i] <= k[candidate_i]:
                # New low: move the candidate trough forward.
                print(d[scan_i], "6 --->>>", d[candidate_i])
                candidate_i = scan_i
            elif k[scan_i] >= k[candidate_i] * (1 + x):
                # Rebounded far enough: the candidate trough is confirmed.
                print(d[scan_i], "7 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_RISE
                candidate_i = scan_i

    # Draw straight segments between consecutive turning points.
    for peer_start_i, peer_end_i in zip(peers, peers[1:]):
        start_value = k[peer_start_i]
        end_value = k[peer_end_i]
        a = (end_value - start_value) / (peer_end_i - peer_start_i)  # slope
        for j in range(peer_end_i - peer_start_i + 1):
            z[j + peer_start_i] = start_value + a * j
    return pd.Series(z), peers
def GOLD_MACD(df_data: pd.DataFrame):
    """
    Gold MACD indicator.

    The input rows are processed in reversed order with a fresh 0-based index.
    ``MACD`` is the percentage change of the 30-period EMA of ``close``
    against its previous value, ``DIF`` is the 5-period EMA of the 2-row sum
    of ``MACD`` and ``DEA`` is the 5-period moving average of ``DIF``.

    :param df_data: frame with the columns ``close``, ``trade_date`` and ``ts_code``
    :return: DataFrame with the columns ``code``, ``date``, ``MACD``, ``DIF``,
        ``DEA``, ``buy1`` (``DIF`` above its previous value) and ``buy2``
        (``DIF`` below its previous value)
    """
    ordered = df_data[::-1].reset_index(drop=True)

    # The 30-period EMA and its lag are each needed twice; compute them once.
    ema_30 = EMA(ordered["close"], 30)
    ema_30_prev = REF(ema_30, 1)
    pct_change = (ema_30 - ema_30_prev) / ema_30_prev * 100

    dif_line = EMA(SUM(pct_change, 2), 5)
    dif_prev = REF(dif_line, 1)
    dea_line = MA(dif_line, 5)

    columns = {
        'code': ordered['ts_code'],
        'date': ordered["trade_date"],
        'MACD': pct_change,
        'DIF': dif_line,
        'DEA': dea_line,
        'buy1': dif_line > dif_prev,
        "buy2": dif_line < dif_prev,
    }
    return pd.DataFrame(columns)
def CCI(DF, n: int = 14):
    """
    Commodity Channel Index.

    ``CCI = (TP - MA(TP, n)) / (0.015 * MD)`` where ``TP`` is the mean of low,
    high and close, and ``MD`` is the mean absolute deviation of ``TP`` from
    its own mean inside each ``n``-row window.

    :param DF: frame with the columns ``low``, ``high`` and ``close``
    :param n: rolling window length
    :return: Series rounded to two decimals; the first ``n - 1`` rows are NaN
    """

    def _mean_abs_deviation(window):
        # Average distance of the window's values from the window mean.
        return abs(window - window.mean()).mean()

    typical_price = (DF['low'] + DF['high'] + DF['close']) / 3
    windows = typical_price.rolling(window=n)
    baseline = windows.mean()
    deviation = windows.apply(_mean_abs_deviation, raw=False)
    return ((typical_price - baseline) / (0.015 * deviation)).round(2)
def INTPART(number):
    """
    Return the integer part of every value.

    Missing values are replaced with 0 first, then the values are cast to
    ``int``.

    :param number: pandas Series (or DataFrame) of numbers
    :return: integer-typed object of the same shape
    """
    return number.fillna(0).astype(int)
def set_zxg_file(self, cont):
    """
    Overwrite the TDX watch-list file (``self.zxg_path``) with the given codes.

    Every item is converted to a string and anything after the first ``.`` is
    dropped (``600000.SH`` -> ``600000``). The code is then written on its own
    line, prefixed with ``1`` or ``0`` depending on which prefix table it
    matches. Codes matching neither table are written without a prefix, and
    codes shorter than two characters are skipped.

    The prefix is now added to the stripped code. It used to be added to the
    original string, so an input such as ``600000.SH`` was written as
    ``1600000.SH``.

    :param cont: iterable of stock codes, with or without an exchange suffix
    :return: None
    """
    # Codes starting with these prefixes are written with a leading '1'.
    prefixes_one = ('50', '51', '60', '688', '73', '90', '110', '113', '132', '204', '78')
    # Codes starting with these prefixes are written with a leading '0'.
    prefixes_zero = ('00', '13', '18', '15', '16', '20', '30', '39', '115', '1318')

    with open(self.zxg_path, "w") as file:
        for item in cont:
            code = str(item).split(".")[0]
            if len(code) <= 1:
                continue
            # The '0' table is checked first because it took precedence before
            # (its test ran last and overwrote the result).
            # NOTE(review): this means '132...' codes match '13' and get '0'
            # although '132' is listed in the '1' table - confirm intent.
            if code.startswith(prefixes_zero):
                code = '0' + code
            elif code.startswith(prefixes_one):
                code = '1' + code
            file.write(code + '\n')