重构部分内容

This commit is contained in:
rm 2023-11-24 10:49:22 +08:00
parent 786723438e
commit e2db3cf16e
19 changed files with 2038 additions and 0 deletions

0
DB/__init__.py Normal file
View File

22
DB/db_config.py Normal file
View File

@ -0,0 +1,22 @@
import os

# MySQL connection settings.
# Every value can be overridden through an environment variable so that real
# credentials do not have to live in source control; the fallbacks are the
# values that used to be hard-coded here.
# NOTE(review): the fallback password is committed in plain text - rotate it
# and provide MYSQL_PASSWORD from the environment instead.
mysql = {
    "url": os.environ.get("MYSQL_URL", "127.0.0.1"),
    "port": os.environ.get("MYSQL_PORT", "3306"),
    "username": os.environ.get("MYSQL_USERNAME", "root"),
    "password": os.environ.get("MYSQL_PASSWORD", "qeadzc123"),
    "database": os.environ.get("MYSQL_DATABASE", "qnloft_hospital"),
}
# Basic stock information
stock_info_db = "stock_info.db"
# Sector related data
stock_sector_db = "stock_sector.db"
# Daily bars per stock
stock_daily_db = "stock_daily.db"
# Minute bars
stock_daily_freq_db = "stock_daily_freq.db"
# Placeholders for other back ends (unused):
# redis_info = ['10.10.XXX', XXX]
# mongo_info = ['10.10.XXX', XXX]
# es_info = ['10.10.XXX', XXX]
# sqlserver_info = ['10.10.XXX:XXX', 'XXX', 'XXX', 'XXX']
# db2_info = ['10.10.XXX', XXX, 'XXX', 'XXX']
# postgre_info = ['XXX', 'XXX', 'XXX', '10.10.XXX', XXX]
# ck_info = ['10.10.XXX', XXX]

271
DB/db_main.py Normal file
View File

@ -0,0 +1,271 @@
import json
import traceback
import pandas
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, scoped_session, aliased, declarative_base, clear_mappers
from sqlalchemy import *
from DB import db_config as config
class DbMain:
    """Shared SQLAlchemy / pandas data-access helpers.

    ``engine``, ``session`` and ``inspector`` start out as ``None``; the
    methods below assume a subclass has assigned them before use.

    Methods that talk to the database catch every exception, print a
    diagnostic and then return ``None``; they close ``self.session`` in a
    ``finally`` clause.
    """

    def __init__(self):
        self.inspector = None
        self.session = None
        self.engine = None

    @staticmethod
    def _report_error(e):
        """Print one diagnostic line per traceback frame of ``e``."""
        details = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        for filename, lineno, funcname, source in traceback.extract_tb(e.__traceback__):
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}"
                  f"错误内容:{details}")

    @staticmethod
    def _model_of(entries):
        """Return ``entries`` itself, or its first element when it is a list."""
        return entries[0] if isinstance(entries, list) else entries

    def get_session(self):
        """Create a session bound to ``self.engine``.

        ``scoped_session`` keeps one session per thread, so the returned
        object is the session belonging to the calling thread.
        """
        session_factory = sessionmaker(bind=self.engine)
        Session = scoped_session(session_factory)
        return Session()

    # ===================== insert or update =============================
    def insert_all_entry(self, entries):
        """Insert a list of ORM objects, creating their table if needed.

        :param entries: list of mapped objects; an empty list is a no-op.
        """
        try:
            if not entries:
                # Nothing to insert; create_table() would index entries[0].
                return
            self.create_table(entries)
            self.session.add_all(entries)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def insert_entry(self, entry):
        """Insert a single ORM object, creating its table if needed."""
        try:
            self.create_table(entry)
            self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def insert_or_update(self, entry, query_conditions):
        """Update the rows matching ``query_conditions`` or insert ``entry``.

        When at least one row matches, the columns returned by
        ``entry.to_dict()`` are written to every matching row; otherwise
        ``entry`` is inserted.

        :param entry: mapped object; must provide ``to_dict()`` for updates.
        :param query_conditions: non-empty dict of column name -> value,
            combined with AND.
        :return: None
        """
        try:
            if not query_conditions:
                raise ValueError("query_conditions must not be empty")
            self.create_table(entry)
            table = entry.__tablename__
            # Values are sent as bound parameters so quotes inside the data
            # cannot break (or inject into) the statement. They are bound as
            # strings because the previous implementation embedded them as
            # quoted string literals. Table and column names still come from
            # the caller's code and are interpolated as-is.
            where_sql = " AND ".join(f"{key} = :w_{i}" for i, key in enumerate(query_conditions))
            where_params = {f"w_{i}": str(value) for i, value in enumerate(query_conditions.values())}
            select_sql = text(f"select count(1) from {table} where 1=1 and {where_sql}")
            matched = self.session.execute(select_sql, where_params).scalar()
            if matched > 0:
                if not hasattr(entry, 'to_dict'):
                    raise Exception("对象 不包含 to_dict 方法,请添加!")
                updates = entry.to_dict()
                set_sql = ", ".join(f"{attr} = :s_{i}" for i, attr in enumerate(updates))
                set_params = {f"s_{i}": str(value) for i, value in enumerate(updates.values())}
                update_sql = text(f"UPDATE {table} SET {set_sql} WHERE {where_sql}")
                self.session.execute(update_sql, {**set_params, **where_params})
            else:
                self.insert_entry(entry)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.expire_all()
            self.session.close()

    def create_table(self, entries):
        """Create the table of ``entries`` (object, class or list) if missing."""
        model = self._model_of(entries)
        if not self.inspector.has_table(model.__tablename__):
            model.metadata.create_all(self.engine)

    # NOTE(review): the ``data=pandas`` default is kept for compatibility; it
    # is the pandas module, not a DataFrame, so callers must always pass data.
    def pandas_insert(self, data=pandas, table_name=""):
        """Append a DataFrame to ``table_name`` using ``DataFrame.to_sql``.

        :param data: DataFrame to store; its index is written to column ``id``.
        :param table_name: target table name.
        :return: None
        """
        try:
            # The context manager returns the connection to the pool.
            with self.engine.connect() as connection:
                data.to_sql(table_name, con=connection, if_exists='append', index=True, index_label='id')
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    # ===================== select =============================
    def query_by_id(self, model, db_id):
        """Return the list of ``model`` rows whose ``id`` equals ``db_id``."""
        try:
            return self.session.query(model).filter_by(id=db_id).all()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
        """Load rows of ``model`` into a DataFrame.

        :param model: mapped class to query.
        :param order_col: optional ORDER BY expression.
        :param page_number: 1-based page number; used only with ``page_size``.
        :param page_size: rows per page.
        :return: DataFrame; empty when the table does not exist.
        """
        try:
            if not self.has_table(model):
                return pandas.DataFrame()
            query = self.session.query(model)
            if order_col is not None:
                query = query.order_by(order_col)
            if page_number is not None and page_size is not None:
                offset = (page_number - 1) * page_size
                query = query.offset(offset).limit(page_size)
            with self.engine.connect() as connection:
                return pandas.read_sql_query(query.statement, connection)
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_sql(self, stmt=""):
        """Run ``stmt`` through ``pandas.read_sql_query`` and return the DataFrame."""
        try:
            with self.engine.connect() as connection:
                return pandas.read_sql_query(sql=stmt, con=connection)
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_condition(self, model, query_condition):
        """Load rows of ``model`` matching ``query_condition``, ordered by id.

        For several conditions combine them first, e.g.
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.

        :return: DataFrame with its index reset.
        """
        try:
            query = self.session.query(model).filter(query_condition).order_by(model.id)
            return self.pandas_query_by_sql(stmt=query.statement).reset_index()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    # ===================== delete =============================
    def delete_by_id(self, model, db_id):
        """Delete the ``model`` rows whose ``id`` equals ``db_id``."""
        try:
            self.session.query(model).filter_by(id=db_id).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def delete_by_condition(self, model, delete_condition):
        """Delete the ``model`` rows matching ``delete_condition``.

        Example conditions: ``StockDaily.trade_date == '20230823'`` or
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        """
        try:
            self.session.query(model).filter(delete_condition).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def delete_all_table(self, model):
        """Delete every row of ``model`` (the table itself is kept)."""
        try:
            self.session.query(model).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    # ===================== misc =============================
    def has_table(self, entries):
        """Return whether the table of ``entries`` (object, class or list) exists."""
        return self.inspector.has_table(self._model_of(entries).__tablename__)

    def execute_sql(self, s):
        """Execute the raw SQL string ``s`` and return the result object.

        NOTE(review): the session is closed before the caller receives the
        result and no commit is issued here - confirm that callers neither
        read rows from the returned object nor expect writes to persist.
        """
        try:
            sql_text = text(s)
            return self.session.execute(sql_text)
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def execute_sql_to_pandas(self, s):
        """Execute the raw SQL string ``s`` and return all rows as a DataFrame."""
        try:
            sql_text = text(s)
            res = self.session.execute(sql_text)
            return pandas.DataFrame(res.fetchall(), columns=res.keys())
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def close(self):
        """Close the current session."""
        self.session.close()

20
DB/model/ConceptSector.py Normal file
View File

@ -0,0 +1,20 @@
from sqlalchemy import Column, Integer, String, Index
from sqlalchemy.orm import declarative_base
_ConceptSectorBase = declarative_base()


class ConceptSector(_ConceptSectorBase):
    """ORM mapping for the concept-sector data table ``stock_concept_sector``."""

    __tablename__ = 'stock_concept_sector'

    # Surrogate key
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Required column
    trade_date = Column(String, nullable=False)

    # Optional columns, all stored as strings
    sector_name = Column(String)
    sector_code = Column(String)
    pct_change = Column(String)
    total_market = Column(String)
    rising_count = Column(String)
    falling_count = Column(String)
    new_price = Column(String)
    leading_stock = Column(String)
    leading_stock_pct_change = Column(String)

View File

@ -0,0 +1,27 @@
# 定义表格映射类
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
_IndustrySectorBase = declarative_base()


class IndustrySector(_IndustrySectorBase):
    """ORM mapping for the industry-sector data table ``stock_industry_sector``."""

    __tablename__ = 'stock_industry_sector'

    # Surrogate key
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Required column
    trade_date = Column(String, nullable=False)

    # Optional columns, all stored as strings
    sector_name = Column(String)
    pct_change = Column(String)
    total_volume = Column(String)
    total_turnover = Column(String)
    net_inflows = Column(String)
    rising_count = Column(String)
    falling_count = Column(String)
    average_price = Column(String)
    leading_stock = Column(String)
    leading_stock_latest_price = Column(String)
    leading_stock_pct_change = Column(String)
# 创建索引
# trade_date_index = Index('idx_trade_date', IndustrySector.trade_date)
# sector_name_index = Index('idx_sector_name', IndustrySector.sector_name)

42
DB/model/StockBasic.py Normal file
View File

@ -0,0 +1,42 @@
# 定义表格映射类
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
# 创建 StockBasic 映射类
_StockBasicBase = declarative_base()


class StockBasic(_StockBasicBase):
    """ORM mapping for the ``stock_basic`` table."""

    __tablename__ = 'stock_basic'

    id = Column(Integer, primary_key=True, autoincrement=True)
    ts_code = Column(String, nullable=False, unique=True)
    symbol = Column(String, nullable=False)
    name = Column(String)
    comp_name = Column(String)
    comp_name_en = Column(String)
    isin_code = Column(String)
    exchange = Column(String)
    list_board = Column(String)
    list_date = Column(String)
    delist_date = Column(String)
    crncy_code = Column(String)
    pinyin = Column(String)
    list_board_name = Column(String)
    is_shsc = Column(String)
    comp_code = Column(String)

    def __init__(self, **kwargs):
        """Store every keyword argument directly in the instance dictionary."""
        vars(self).update(kwargs)

    def __repr__(self):
        return f"<StockBasic(ts_code='{self.ts_code}', name='{self.name}', exchange='{self.exchange}')>"

    def to_dict(self):
        """Return the subset of attributes that is written on update."""
        kept = ('delist_date', 'is_shsc', 'comp_name', 'comp_name_en')
        return {attr: getattr(self, attr) for attr in kept}

View File

@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
_StockByConceptSectorBase = declarative_base()


class StockByConceptSector(_StockByConceptSectorBase):
    """ORM mapping for ``stock_by_concept_sector`` (stocks per concept sector)."""

    __tablename__ = 'stock_by_concept_sector'

    # Surrogate key
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Required column
    update_date = Column(String, nullable=False)

    # Sector columns
    sector_name = Column(String)
    sector_code = Column(String)

    # Stock columns
    stock_name = Column(String)
    stock_code = Column(String)

View File

@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base
_StockByIndustrySectorBase = declarative_base()


class StockByIndustrySector(_StockByIndustrySectorBase):
    """ORM mapping for ``stock_by_industry_sector`` (stocks per industry sector)."""

    __tablename__ = 'stock_by_industry_sector'

    # Surrogate key
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Required column
    update_date = Column(String, nullable=False)

    # Sector columns
    sector_name = Column(String)
    sector_code = Column(String)

    # Stock columns
    stock_name = Column(String)
    stock_code = Column(String)

58
DB/model/StockDaily.py Normal file
View File

@ -0,0 +1,58 @@
# 定义表格映射类
# 定义表格映射类
from sqlalchemy import Column, Integer, String, Float
from sqlalchemy.orm import declarative_base
def get_stock_daily(table_name):
    """Build a fresh ``StockDaily`` ORM class mapped to ``table_name``.

    Every call creates a new declarative base and a new class, so the same
    column layout can be mapped to many differently named tables.

    :param table_name: name of the table the returned class maps to.
    :return: the newly created ``StockDaily`` class.
    """
    base = declarative_base()

    class StockDaily(base):
        __tablename__ = table_name

        id = Column(Integer, primary_key=True, autoincrement=True)
        ts_code = Column(String, nullable=False)  # TS code
        trade_date = Column(String, nullable=False)  # trade date
        crncy_code = Column(String)  # currency code
        pre_close = Column(Float)  # previous close (yuan)
        open = Column(Float)  # open (yuan)
        high = Column(Float)  # high (yuan)
        low = Column(Float)  # low (yuan)
        close = Column(Float)  # close (yuan)
        change = Column(Float)  # change (yuan)
        pct_chg = Column(Float)  # percent change (%)
        volume = Column(Float)  # volume (lots)
        amount = Column(Float)  # turnover (thousand yuan)
        adj_pre_close = Column(Float)  # adjusted previous close (yuan)
        adj_open = Column(Float)  # adjusted open (yuan)
        adj_high = Column(Float)  # adjusted high (yuan)
        adj_low = Column(Float)  # adjusted low (yuan)
        adj_close = Column(Float)  # adjusted close (yuan)
        adj_factor = Column(Float)  # adjustment factor
        avg_price = Column(Float)  # average price (VWAP)
        trade_status = Column(String)  # trading status
        turnover_rate = Column(Float)  # turnover rate
        turnover_rate_f = Column(Float)  # turnover rate (free-float shares)
        volume_ratio = Column(Float)  # volume ratio
        pe = Column(Float)  # P/E (total market cap / net profit)
        pe_ttm = Column(Float)  # P/E TTM
        pb = Column(Float)  # P/B (total market cap / net assets)
        ps = Column(Float)  # P/S
        ps_ttm = Column(Float)  # P/S TTM
        dv_ratio = Column(Float)  # dividend yield
        dv_ttm = Column(Float)  # trailing dividend yield
        total_share = Column(Float)  # total shares
        float_share = Column(Float)  # float shares
        free_share = Column(Float)  # free-float shares
        total_mv = Column(Float)  # total market cap
        circ_mv = Column(Float)  # circulating market cap

        def __init__(self, **kwargs):
            # Keyword arguments are stored directly in the instance dict.
            vars(self).update(kwargs)

        def __repr__(self):
            return f"<StockDaily(ts_code='{self.ts_code}', trade_date='{self.trade_date}', close='{self.close}')>"

        def to_dict(self):
            """Return the attributes that are written on update."""
            kept = ("ts_code", "trade_date", "pre_close", "turnover_rate")
            return {attr: getattr(self, attr) for attr in kept}

    return StockDaily

View File

@ -0,0 +1,26 @@
from sqlalchemy import Column, Integer, String, Float
from sqlalchemy.orm import declarative_base
def get_stock_daily_freq(table_name):
    """Build a fresh ``StockDailyFreq`` ORM class mapped to ``table_name``.

    Every call creates a new declarative base and a new class.

    :param table_name: name of the table the returned class maps to.
    :return: the newly created ``StockDailyFreq`` class.
    """
    base = declarative_base()

    class StockDailyFreq(base):
        __tablename__ = table_name

        id = Column(Integer, primary_key=True)
        trade_date = Column(String)  # trade date
        time = Column(String)  # time
        open = Column(Float)  # open price
        close = Column(Float)  # close price
        high = Column(Float)  # high price
        low = Column(Float)  # low price
        vol = Column(Float)  # volume (note: unit is lots)
        amount = Column(Float)  # turnover

        def __init__(self, **kwargs):
            # Keyword arguments are stored directly in the instance dict.
            vars(self).update(kwargs)

        def to_dict(self):
            """Return the attributes that are written on update."""
            kept = ("trade_date", "time", "close")
            return {attr: getattr(self, attr) for attr in kept}

    return StockDailyFreq

0
DB/model/__init__.py Normal file
View File

26
DB/mysql_db_main.py Normal file
View File

@ -0,0 +1,26 @@
from DB.db_main import *
def create_mysql_engine():
    """Create the SQLAlchemy engine for the MySQL database in ``config.mysql``.

    The URL is assembled with ``URL.create`` so that user names or passwords
    containing characters such as ``@``, ``:`` or ``/`` are escaped correctly
    instead of corrupting a hand-built URL string.

    :return: engine using the ``mysql+pymysql`` driver with ``utf8mb4``.
    """
    from sqlalchemy.engine import URL

    settings = config.mysql
    url = URL.create(
        "mysql+pymysql",
        username=settings['username'],
        password=settings['password'],
        host=settings['url'],
        port=int(settings['port']),
        database=settings['database'],
        query={"charset": "utf8mb4"},
    )
    return create_engine(
        url,
        # Maximum number of connections created beyond pool_size
        max_overflow=0,
        # Connection pool size
        pool_size=5,
        # Seconds to wait for a free pooled connection before raising
        pool_timeout=10,
        # Age in seconds after which a pooled connection is recycled
        # NOTE(review): 1 second recycles almost every checkout - confirm intended
        pool_recycle=1,
        # Log the emitted SQL statements
        echo=True
    )
class MysqlDbMain(DbMain):
    """``DbMain`` bound to the MySQL database described in ``config.mysql``."""

    def __init__(self):
        super().__init__()
        self.engine = create_mysql_engine()
        self.session = self.get_session()
        # DbMain.create_table()/has_table() call self.inspector.has_table();
        # without an inspector they fail on None.
        self.inspector = inspect(self.engine)

50
DB/sqlite_db_main.py Normal file
View File

@ -0,0 +1,50 @@
from pathlib import Path
from DB.db_main import *
import platform
import logging
class SqliteDbMain(DbMain):
    """``DbMain`` bound to a SQLite database file chosen per operating system."""

    def __init__(self, database_name):
        """
        :param database_name: file name of the SQLite database.
        """
        self.database_name = database_name
        super().__init__()
        self.engine = self.__create_sqlite_engine()
        self.engine_path = self.__get_path()
        self.session = self.get_session()
        self.inspector = inspect(self.engine)

    def __platform_path(self):
        """Return the database file path for the current platform."""
        system = platform.platform().lower()
        if 'windows' in system:
            return f"E:\\sqlite_db\\stock_db\\{self.database_name}"
        if 'macos' in system:
            return f"/Users/renmeng/Documents/sqlite_db/{self.database_name}"
        return f"{self.database_name}"

    def __get_path(self):
        """Print the current platform and return the database file path."""
        location = self.__platform_path()
        print(f'当前操作系统:{platform.platform()}')
        return location

    def __create_sqlite_engine(self):
        """Create the engine for the platform specific database file."""
        url = f"sqlite:///{self.__platform_path()}"
        print(f"当前__engine是{url}")
        return create_engine(url, pool_size=10, pool_timeout=10, echo=True)

    def get_db_size(self):
        """Print and return the database file size formatted as ``"x.xx MB"``."""
        size_in_bytes = Path(self.engine_path).stat().st_size
        total = f"{size_in_bytes / (1024 * 1024):.2f} MB"
        print(f"文件大小: {total}")
        return total

119
fp/分钟线数据入库.py Normal file
View File

@ -0,0 +1,119 @@
from datetime import timedelta
from DB.model.StockDailyFreq import get_stock_daily_freq
from DB.sqlite_db_main import SqliteDbMain, config
from utils.tdxUtil import TdxUtil
from utils.comm import *
from 基本信息入库 import StockInfoMain
class StockDailyFreqMain:
    """Stores 30-minute bars per stock in ``<code>_daily_freq_30`` tables."""

    # Number of download attempts before a stock is skipped.
    _MAX_ATTEMPTS = 7

    def __init__(self):
        self.code_res = StockInfoMain().get_stock_basic()
        self.db_main = SqliteDbMain(config.stock_daily_freq_db)

    def __filter_stock(self, code, name):
        """Return False for names containing 'ST' and codes starting with 30 or 68."""
        if 'ST' in name:
            return False
        return not code.startswith(('30', '68'))

    def __fetch_30min(self, ts_code, start_time, end_time):
        """Download the 30-minute bars in reverse row order, retrying on errors.

        :return: the DataFrame, or None when every attempt failed.
        """
        for _ in range(self._MAX_ATTEMPTS):
            try:
                return xcsc_pro.stk_mins(ts_code=ts_code, freq='30min',
                                         start_time=start_time,
                                         end_time=end_time)[::-1]
            except Exception:
                print("请求超时等待2分钟后重试...")
                time.sleep(120)  # wait two minutes before the next attempt
        return None

    def init_data(self):
        """
        Full load: store the last three years of 30-minute bars for every
        stock that passes the filters.
        :return:
        """
        for result in self.code_res:
            tdx_util = TdxUtil("")
            s_type = tdx_util.get_security_type(code=result.ts_code)
            if not (s_type in tdx_util.SECURITY_TYPE and self.__filter_stock(result.ts_code, result.name)):
                continue
            start_time = f"{(datetime.now() - timedelta(days=365 * 3)).strftime('%Y-%m-%d')} 09:30:00"
            end_time = f"{datetime.now().strftime('%Y-%m-%d')} 15:00:00"
            print(f"{result.id}{result.ts_code} --> start_time={start_time} --->> end_time={end_time}")
            df = self.__fetch_30min(result.ts_code, start_time, end_time)
            if df is None:
                # Every attempt failed; continue with the next stock.
                print(f"{result.ts_code}{result.name} 获取 stk_mins _ 30 数据失败,跳过!")
                continue
            print(f"{result.ts_code}{result.name} 获取 stk_mins _ 30 数据完毕!")
            table_name = str(result.ts_code).split(".")[0] + "_daily_freq_30"
            new_table_class = get_stock_daily_freq(table_name=table_name)
            self.db_main.create_table(new_table_class)
            entries = []
            for index, row in df.iterrows():
                min_time = row['trade_time']
                # Keep only the date part of the bar timestamp
                trade_time = datetime.strptime(min_time, '%Y-%m-%d %H:%M:%S').strftime('%Y-%m-%d')
                entry = new_table_class(
                    trade_date=trade_time,
                    time=min_time,
                    open=row['open'],
                    close=row['close'],
                    high=row['high'],
                    low=row['low'],
                    vol=row['vol'],
                    amount=row['amount'],
                )
                print(entry.to_dict())
                entries.append(entry)
            self.db_main.insert_all_entry(entries)

    def task_data(self, trade_date=None):
        """Store the 30-minute bars of one day for every stock passing the filters.

        :param trade_date: day to load; defaults to today, evaluated at call
            time. NOTE(review): the default uses '%Y%m%d' as before while the
            script entry point passes '%Y-%m-%d' - confirm which format the
            API expects.
        """
        if trade_date is None:
            trade_date = datetime.now().strftime('%Y%m%d')
        print(len(self.code_res))
        for result in self.code_res:
            tdx_util = TdxUtil("")
            s_type = tdx_util.get_security_type(code=result.ts_code)
            if not (s_type in tdx_util.SECURITY_TYPE and self.__filter_stock(result.ts_code, result.name)):
                continue
            print(
                f"{result.id}{result.ts_code} --> start_time={trade_date} 09:30:00 --->> end_time={trade_date} 15:00:00")
            df = xcsc_pro.stk_mins(ts_code=result.ts_code, freq='30min',
                                   start_time=f"{trade_date} 09:30:00",
                                   end_time=f"{trade_date} 15:00:00")
            print(f"{result.ts_code},{result.name} 获取 stk_mins _ 30 数据完毕!")
            # One mapped class per stock is enough; it does not depend on the row.
            table_name = str(result.ts_code).split(".")[0] + "_daily_freq_30"
            new_table_class = get_stock_daily_freq(table_name=table_name)
            for index, row in df.iterrows():
                entry = new_table_class(
                    trade_date=trade_date,
                    time=row['trade_time'],
                    open=row['open'],
                    close=row['close'],
                    high=row['high'],
                    low=row['low'],
                    vol=row['vol'],
                    amount=row['amount'],
                )
                print(entry.to_dict())
                self.db_main.insert_entry(entry)
if __name__ == '__main__':
    # Run the daily task only when today is a trading day.
    current_date = datetime.now()
    is_trading_day = if_run(current_date)
    if is_trading_day:
        trade_date = current_date.strftime('%Y-%m-%d')
        print(trade_date)
        # StockDailyFreqMain().init_data() performs the full load instead.
        StockDailyFreqMain().task_data(trade_date)
    # Manual check of the raw API:
    # df = xcsc_pro.stk_mins(ts_code='600095.SH', start_time='2022-08-24 09:30:00', end_time='2023-08-24 15:00:00',
    #                        freq='30min')
    # print(df)

71
fp/基本信息入库.py Normal file
View File

@ -0,0 +1,71 @@
from sqlalchemy import and_
from DB.model.StockBasic import StockBasic
from DB.sqlite_db_main import SqliteDbMain, config
from utils.comm import *
class StockInfoMain:
    """Reads and writes the ``stock_basic`` table of the stock-info database."""

    def __init__(self):
        self.db_main = SqliteDbMain(config.stock_info_db)
        self.session = self.db_main.get_session()

    def insert_stock_basic(self):
        """Download the basic stock list of both exchanges and upsert each row by ``ts_code``."""
        # SSE: Shanghai Stock Exchange, SZSE: Shenzhen Stock Exchange
        for exchange in ('SZSE', 'SSE'):
            df = xcsc_pro.stock_basic(exchange=exchange)
            for index, row in df.iterrows():
                entry = StockBasic(**row)
                self.db_main.insert_or_update(entry, query_conditions={'ts_code': row['ts_code']})

    def get_stock_basic(self, ts_code=None, symbol=None, restart_id=0):
        """Return the ``StockBasic`` rows whose ``delist_date`` is the string 'None'.

        At most one extra filter is applied, in this order of precedence:

        * ``ts_code`` - a string (exact match) or a list (membership);
        * ``symbol`` - a string (exact match) or a list (membership);
        * ``restart_id`` - when greater than 0, only rows with ``id >= restart_id``.

        :return: list of ``StockBasic`` objects.
        """
        # delist_date holds the text 'None' for the rows selected here, so a
        # plain equality test is used.
        not_delisted = StockBasic.delist_date == 'None'
        try:
            selector = None
            if ts_code is not None:
                if isinstance(ts_code, str):
                    selector = StockBasic.ts_code == ts_code
                elif isinstance(ts_code, list):
                    selector = StockBasic.ts_code.in_(ts_code)
            elif symbol is not None:
                if isinstance(symbol, str):
                    selector = StockBasic.symbol == symbol
                elif isinstance(symbol, list):
                    # Filter on the symbol column (this branch used ts_code before).
                    selector = StockBasic.symbol.in_(symbol)
            if selector is None and restart_id > 0:
                selector = StockBasic.id >= restart_id
            if selector is None:
                sql_cond = and_(not_delisted)
            else:
                sql_cond = and_(selector, not_delisted)
            return self.session.query(StockBasic).filter(sql_cond).all()
        finally:
            self.session.close()
if __name__ == '__main__':
    # Print the rows starting at id 5.
    stock_rows = StockInfoMain().get_stock_basic(restart_id=5)
    print(stock_rows)

124
fp/新闻资讯.py Normal file
View File

@ -0,0 +1,124 @@
from utils.comm import *
from snownlp import SnowNLP
import jieba
import jieba.posseg as pseg
import jieba.analyse
class News:
    """Fetches news feeds and offers small text-analysis helpers."""

    def __init__(self, trade_date=None):
        """
        :param trade_date: datetime of the news day; defaults to now,
            evaluated at call time rather than once at import.
        """
        if trade_date is None:
            trade_date = datetime.now()
        self.trade_date = trade_date.strftime('%Y%m%d')

    @staticmethod
    def preprocess(text):
        """Tokenize ``text`` with jieba and return the token list."""
        seg_list = jieba.lcut(text)
        return seg_list

    @staticmethod
    def sentiment_analysis(text):
        """Return the SnowNLP sentiment score of ``text``.

        Scores range from 0 to 1; values near 1 are positive, near 0 negative.
        """
        s = SnowNLP(text)
        sentiment = s.sentiments
        return sentiment

    @staticmethod
    def extract_keywords(text, top_k=5):
        """Return the ``top_k`` keywords of ``text`` extracted by jieba."""
        keywords = jieba.analyse.extract_tags(text, topK=top_k, withWeight=False, allowPOS=())
        return keywords

    @staticmethod
    def named_entity_recognition(text):
        """Return ``(word, flag)`` pairs for person, place, organisation and other proper nouns."""
        wanted_flags = ('nr', 'ns', 'nt', 'nz')
        return [(word, flag) for word, flag in pseg.cut(text) if flag in wanted_flags]

    def result(self, news):
        """Print tokens, sentiment score, keywords and named entities of ``news``."""
        print("原文:", news)
        preprocessed_news = self.preprocess(news)
        sentiment_score = self.sentiment_analysis(news)
        keywords = self.extract_keywords(news, top_k=3)
        entities = self.named_entity_recognition(news)
        print("分词结果:", preprocessed_news)
        print("情感分析得分:", round(sentiment_score, 2))
        print("关键词提取:", keywords)
        print("命名实体识别:", entities)

    def news_cctv(self):
        """Fetch the CCTV news of ``self.trade_date``.

        :return: the DataFrame, or None after five failed attempts.
        """
        for _ in range(5):
            try:
                return ak.news_cctv(date=self.trade_date)
            except Exception as e:
                # Print the error itself; e.__traceback__ only shows an object repr.
                print(f"{self.trade_date} 日的新闻咨询拉取发生错误!{e}")
                time.sleep(1)
        print("中央新闻 5次出现错误请关注")
        return None

    def news_cls(self):
        """Fetch the CLS telegraph feed.

        :return: the DataFrame, or None after five failed attempts.
        """
        for _ in range(5):
            try:
                return ak.stock_telegraph_cls()
            except Exception as e:
                print(f"{self.trade_date} 财联社-电报 咨询拉取发生错误!{e}")
                # Same pause between attempts as news_cctv
                time.sleep(1)
        print("财联社-电报 5次出现错误请关注")
        return None

    @staticmethod
    def news_stock_by_code(symbol, date_time=None):
        """Return the news of one stock published on a given day.

        :param symbol: stock code passed to ``ak.stock_news_em``.
        :param date_time: day as 'YYYYMMDD' or 'YYYY-MM-DD'; today when None.
        :return: DataFrame with the matching rows.
        """
        stock_news_em_df = ak.stock_news_em(symbol=symbol)
        # Convert the publish-time column to datetimes
        stock_news_em_df['发布时间'] = pd.to_datetime(stock_news_em_df['发布时间'])
        if date_time is not None:
            raw = str(date_time)
            date_format = '%Y%m%d' if len(raw.split("-")) == 1 else '%Y-%m-%d'
            # Compare date with date: a string never equals the .dt.date values
            target_date = datetime.strptime(raw, date_format).date()
        else:
            target_date = datetime.now().date()
        today_news = stock_news_em_df[stock_news_em_df['发布时间'].dt.date == target_date]
        print(today_news['新闻标题'])
        print(today_news["新闻内容"])
        return today_news

    @staticmethod
    def _rows_html(df, column, skip_short=False):
        """Render one escaped table row per title in ``df[column]``.

        :param df: DataFrame or None (rendered as an empty string).
        :param skip_short: skip titles that are at most one character long.
        """
        from html import escape

        if df is None:
            return ''
        rows = []
        for title in df[column]:
            if skip_short and len(title) <= 1:
                continue
            # Titles come from external feeds, so escape them for HTML.
            rows.append(f'<tr><th scope="row">{escape(str(title))}</th>'
                        f'<td>开发中...</td></tr>')
        return ''.join(rows)

    def html_page_data(self):
        """Return ``(cctv_rows, cls_rows)`` as HTML table-row strings.

        A feed that could not be fetched yields an empty string.
        """
        cctv_content = self._rows_html(self.news_cctv(), "title")
        cls_content = self._rows_html(self.news_cls(), "标题", skip_short=True)
        return cctv_content, cls_content
if __name__ == '__main__':
    # NOTE(review): date_obj is parsed but not passed to News() - confirm intent.
    date_obj = datetime.strptime('20231019', "%Y%m%d")
    cctv_news = News().news_cctv()
    print(cctv_news)

285
utils/comm.py Normal file
View File

@ -0,0 +1,285 @@
import ftplib
import os
from datetime import datetime
from ftplib import FTP, error_perm
from pathlib import Path
import akshare as ak
import efinance as ef
import qstock as qs
import tushare as ts
import pandas as pd
import time
import random
from tabulate import tabulate
import platform
import xcsc_tushare as xc
# API tokens. Environment variables take precedence so that real tokens do
# not have to live in source control; the fallbacks are the values that used
# to be hard-coded here.
# NOTE(review): both fallback tokens are committed in plain text - rotate them.
# NOTE(review): `token` is not used in this part of the file; presumably it is
# a tushare token - confirm against its users.
token = os.environ.get('TUSHARE_TOKEN', '0718534658b9d91b3f03dc8b220e4062193ebf6f6414d036505165e1')
xc.set_token(os.environ.get('XCSC_TUSHARE_TOKEN', 'bd4f26f1eca8d660bd23e229260df46002d630d6e1fb9226380edec8'))
# Original note: this way of connecting works for the simulation environment
# but not for the production environment.
xcsc_pro = xc.pro_api(env='prd', server='http://116.128.206.39:7172')
# Show all columns
pd.set_option('display.max_columns', None)
# Show all rows
pd.set_option('display.max_rows', None)
# Do not wrap wide frames
pd.set_option('expand_frame_repr', False)
# No limit on column width
pd.set_option('display.max_colwidth', None)
def sleep():
    """Pause for a random 0-2 seconds, printing a message before and after."""
    print("开始休眠防止ip被拉黑或者进小黑屋")
    # Random pause length so requests are not evenly spaced
    time.sleep(random.uniform(0, 2))
    print("休眠结束!")
def print_markdown(df: pd.DataFrame):
    """Print ``df`` as a GitHub-flavoured markdown table with its column names as headers."""
    print(tabulate(df, headers='keys', tablefmt='github'))
def get_file(size):
    """
    Return the platform specific path of the stock list file named ``size``.

    :param size: file name inside the stock data directory.
    :return: the path on macOS and Windows; on any other platform a notice is
        printed and None is returned.
    """
    sys_platform = platform.platform().lower()
    if "macos" in sys_platform:
        return f'/Users/renmeng/work_space/python_work/qnloft-get-web-everything/股票金融/量化交易/股票数据/{size}'
    if "windows" in sys_platform:
        # Every backslash is escaped; the former "\q" was an invalid escape
        # sequence that only produced the right text by accident.
        return f'E:\\Python_Workplace\\qnloft-get-web-everything\\股票金融\\量化交易\\股票数据\\{size}'
    print("其他系统")
    return None
def write_file(content, file_name, mode_type='a'):
    """Write ``content`` to ``file_name`` inside the platform specific data directory.

    Missing parent directories and the file itself are created first. On
    platforms other than macOS and Windows a notice is printed and nothing is
    written.

    :param content: text to write.
    :param file_name: file name relative to the data directory.
    :param mode_type: mode passed to ``open``; appends by default.
    """
    sys_platform = platform.platform().lower()
    if "macos" in sys_platform:
        target = f'/Users/renmeng/work_space/python_work/qnloft-get-web-everything/股票金融/量化交易/股票数据/{file_name}'
    elif "windows" in sys_platform:
        target = f'D:\\文档\\数据测试\\{file_name}'
    else:
        # No data directory is known here; opening an empty path would fail.
        print("其他系统")
        return
    path = Path(target)
    if not path.exists():
        # touch() needs the parent directory to exist
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()
    with path.open(mode=mode_type) as handle:
        handle.write(content)
def create_file(file_name):
    """Create the directory ``file_name`` with its parents unless the path already exists."""
    target = Path(file_name)
    if target.exists():
        return
    # Creates the directory together with any missing parents
    target.mkdir(parents=True)
def get_random_stock(size, n=0):
    """Return the lines of the stock list file for ``size``.

    :param size: one of 'M', 'S' or 'L'.
    :param n: when greater than 0, return a random sample of ``n`` lines
        instead of all lines.
    :raises ValueError: for any other ``size``.
    """
    if size not in ('M', 'S', 'L'):
        raise ValueError("size输入错误")
    stock_lines = Path(get_file(size)).read_text().splitlines()
    return random.sample(stock_lines, n) if n > 0 else stock_lines
def del_file_lines(size, to_delete):
    """Remove every occurrence of each item of ``to_delete`` from the stock list file for ``size``."""
    target = Path(get_file(size))
    remaining = target.read_text()
    # Strip each matching piece of text in turn
    for piece in to_delete:
        remaining = remaining.replace(piece, '')
    # Write the updated content back to the same file
    with target.open(mode='w') as handle:
        handle.write(remaining)
def if_run(current_date=None):
    """
    Check whether ``current_date`` appears in the trading calendar returned by
    ``ak.tool_trade_date_hist_sina()``.

    :param current_date: datetime to check; defaults to now, evaluated at
        call time rather than once at import.
    :return: True when the date is listed, otherwise False.
    """
    if current_date is None:
        current_date = datetime.now()
    df = ak.tool_trade_date_hist_sina()
    # Convert the date column to datetimes
    df['trade_date'] = pd.to_datetime(df['trade_date'])
    # Format the date as "yyyy-mm-dd"
    formatted_date = current_date.strftime('%Y-%m-%d')
    # bool() so callers receive a plain Python bool
    contains_target_date = bool((df['trade_date'] == formatted_date).any())
    print(f'当日日期是:{formatted_date} ,今日是否开盘:{contains_target_date}')
    return contains_target_date
def return_trading_day(now_date, offset_date, format_type='%Y%m%d'):
    """
    Return the trading day ``offset_date`` calendar rows away from ``now_date``.

    :param now_date: datetime or date string; must be present in the calendar.
    :param offset_date: row offset; negative for earlier, positive for later days.
    :param format_type: strftime format of the returned string.
    :return: the formatted trading day.
    """
    calendar = ak.tool_trade_date_hist_sina()
    # Convert the date column to datetimes
    calendar['trade_date'] = pd.to_datetime(calendar['trade_date'])
    day = now_date.strftime('%Y-%m-%d') if isinstance(now_date, datetime) else now_date
    matches = calendar[calendar['trade_date'] == day].index
    shifted = calendar.loc[matches + offset_date, 'trade_date']
    return shifted.dt.strftime(format_type).values[0]
def _ftp_replace_file(ftp, local_file, remote_file):
    """Upload ``local_file`` as ``remote_file``, deleting an existing remote copy first."""
    print(local_file, "----->>", remote_file)
    with open(local_file, 'rb') as f:
        try:
            ftp.delete(remote_file)
        except ftplib.error_perm:
            print("目录不存在, 或已经删除...")
        ftp.storbinary(f"STOR {remote_file}", f)


def upload_files_to_ftp(local_directory, remote_directory, file_name=None):
    """
    Upload files from ``local_directory`` to ``remote_directory`` over FTP.

    :param local_directory: local directory holding the files.
    :param remote_directory: remote directory; created when it cannot be entered.
    :param file_name: single file to upload; when None every ``.html`` file
        of ``local_directory`` is uploaded.
    :return: False when connecting or logging in fails, otherwise True.
    """
    ftp = None
    try:
        # NOTE(review): host and credentials are hard-coded in plain text -
        # consider moving them to configuration.
        ftp = FTP("qxu1142200198.my3w.com")
        ftp.login('qxu1142200198', '48XZ55MB')
    except ftplib.all_errors as e:
        # all_errors also covers socket level failures, not only FTP replies
        print(f"Network error: {e}")
        if ftp is not None:
            ftp.close()
        return False
    try:
        try:
            ftp.cwd(remote_directory)
        except ftplib.error_perm:
            print("目录不存在, 开始创建...")
            ftp.mkd(remote_directory)
        if Path(local_directory).is_dir() and file_name is None:
            for file in os.listdir(local_directory):
                if file.endswith(".html"):  # upload only .html files
                    _ftp_replace_file(ftp, f"{local_directory}/{file}", f"{remote_directory}/{file}")
        if file_name is not None and len(file_name) > 0:
            _ftp_replace_file(ftp, f"{local_directory}/{file_name}", f"{remote_directory}/{file_name}")
    finally:
        ftp.quit()
    return True
def back_testing(df, cond: str, n=3):
"""
回撤测试
:param df:
:param cond:
:param n:
:return:
"""
filtered_df = df.eval(cond)
total, tomorrow_rise, tomorrow_fall = 0, 0, 0
res_df = pd.DataFrame(
columns=['code', '日期', '明天涨跌幅', 'n日最大涨幅', 'n日平均涨跌幅', 'n日最大回撤'])
for index, row in df.iterrows():
if filtered_df[index] and index > n:
# print(f'{row["trade_date"]} --> {row["KDJ_D"]} --> {row["KDJ_J"]}')
total += 1
# 明天涨跌幅情况
tomorrow_pre = df.iloc[-index:-index + 1]['pct_chg'].values[0].round(2)
# n日 最大涨幅
max_h = df.iloc[-index:-index + n]['pct_chg'].max().round(2)
# n日 平均涨跌幅
mean_h = df.iloc[-index:-index + n]['pct_chg'].mean().round(2)
# n日 最大回撤
h_max = df.iloc[-index:-index + n]['high'].max()
l_min = df.iloc[-index:-index + n]['low'].min()
max_ret = f"{(h_max - l_min) / l_min:.2f}"
res_df.loc[total, 'code'], res_df.loc[total, '日期'] = row['ts_code'], row["trade_date"]
res_df.loc[total, '明天涨跌幅'] = tomorrow_pre
res_df.loc[total, 'n日最大涨幅'] = max_h
res_df.loc[total, 'n日平均涨跌幅'] = mean_h
res_df.loc[total, 'n日最大回撤'] = max_ret
return res_df
def buying_and_selling_decisions(df, cond_buy: str, cond_sell: str, cond_buy_read=None, cond_sell_read=None):
    """
    Simulate a single-position buy/sell strategy over the rows of ``df``.

    Each condition is a Python expression string evaluated with ``eval`` inside this
    function for every row, so it can refer to the loop locals (``row``, ``index``,
    ``trade_date``, ``c`` ...).
    WARNING: ``eval`` executes arbitrary code -- only pass trusted condition strings.

    :param df: frame providing at least ``trade_date`` (``%Y%m%d`` strings), ``close`` and ``ts_code``
    :param cond_buy: expression that triggers a buy once the position is armed
    :param cond_sell: expression that triggers a sell once the exit is armed
    :param cond_buy_read: optional expression arming the buy; ``None`` arms it on every flat bar
    :param cond_sell_read: optional expression arming the sell; ``None`` arms it on every bar
        while a position is held
    :return: DataFrame with one row per completed trade, indexed from 1
    """
    # Every local keeps its original name: the condition strings are evaluated in this
    # scope and may reference any of them.
    total, buy, sell, total_chg, buy_date, flag_ready_b, flag_ready_s = 0, 0, 0, 0, 0, False, False
    res_df = pd.DataFrame(
        columns=['code', '买入日期', '卖出日期', '盈利', '持仓天数'])
    for index, row in df.iterrows():
        trade_date, c = row["trade_date"], row['close']
        # The None test must come first: eval(None) raises TypeError, which made the
        # default arguments unusable.
        if (cond_buy_read is None or eval(cond_buy_read)) and buy == 0:
            # Arm the buy.
            print(f"{trade_date} {cond_buy_read},准备买入!")
            flag_ready_b = True
        if eval(cond_buy) and buy == 0 and flag_ready_b:
            # Open the position at the close.
            print(f"{trade_date} {cond_buy},买入操作!")
            buy, buy_date = c, trade_date
            flag_ready_b = False
        if (cond_sell_read is None or eval(cond_sell_read)) and buy > 0:
            # Arm the sell.
            print(f"{trade_date} {cond_sell_read},准备卖出!")
            flag_ready_s = True
        if eval(cond_sell) and buy > 0 and flag_ready_s:
            print(f"{trade_date} {cond_sell},卖出操作!")
            total += 1
            # Close the position at the close.
            sell = c
            # Profit in percent of the entry price.
            chg = round(((sell - buy) / buy) * 100, 2)
            # Holding period in calendar days.
            start_date = datetime.strptime(buy_date, "%Y%m%d")
            end_date = datetime.strptime(trade_date, "%Y%m%d")
            interval_days = (end_date - start_date).days
            buy, flag_ready_s = 0, False
            res_df.loc[total, 'code'], res_df.loc[total, '买入日期'] = row['ts_code'], buy_date
            res_df.loc[total, '卖出日期'] = trade_date
            res_df.loc[total, '盈利'] = chg
            res_df.loc[total, '持仓天数'] = interval_days
    return res_df

691
utils/formula.py Normal file
View File

@ -0,0 +1,691 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
def EMA(number, n):
    """
    Exponentially weighted mean of ``number`` with smoothing factor ``2 / (n + 1)``.

    Uses pandas' bias-adjusted weighting (``adjust=True``).

    :param number: values to smooth; wrapped in a ``pd.Series``
    :param n: period that defines the smoothing factor
    :return: smoothed ``pd.Series``
    """
    smoothing = 2 / (n + 1)
    values = pd.Series(number)
    return values.ewm(alpha=smoothing, adjust=True).mean()
def MA(number, n):
    """
    Simple moving average over a trailing window of ``n`` values.

    :param number: pandas Series to average
    :param n: window length
    :return: Series of rolling means (NaN until the window is full)
    """
    window = number.rolling(n)
    return window.mean()
def SMA(number, n, m=1):
    """
    Exponentially weighted mean with centre of mass ``n - m`` (``adjust=True``).

    Missing values are replaced by 0 before smoothing.

    :param number: pandas Series to smooth
    :param n: period
    :param m: weight, defaults to 1
    :return: smoothed ``pd.Series``
    """
    filled = pd.Series(number.fillna(0))
    return filled.ewm(com=n - m, adjust=True).mean()
def RM_SMA(DF, N, M):
    """
    Recursive smoothing ``Y[i] = (X[i] * M + Y[i - 1] * (N - M)) / N`` seeded with ``Y[0] = X[0]``.

    Missing values are treated as 0. The caller's Series is not modified.
    Elements are addressed as ``DF[i]`` for ``i`` in ``0 .. len - 1``.
    NOTE(review): with an integer index this is label-based access -- confirm callers
    pass a Series labelled ``0 .. len - 1``.

    :param DF: pandas Series of input values
    :param N: period
    :param M: weight of the current value
    :return: float Series with the smoothed values (empty when ``DF`` is empty)
    """
    # Work on a float copy so the float results can be written back without
    # relying on implicit dtype upcasting.
    DF = DF.fillna(0).astype(float)
    z = len(DF)
    if z == 0:
        # Nothing to seed the recursion with.
        return DF
    var = np.zeros(z)
    var[0] = DF[0]
    for i in range(1, z):
        var[i] = (DF[i] * M + var[i - 1] * (N - M)) / N
    for i in range(z):
        DF[i] = var[i]
    return DF
def ATR(close, high, low, n):
    """
    Average true range.

    :param close: closing prices
    :param high: high prices
    :param low: low prices
    :param n: averaging window
    :return: DataFrame with ``MTR`` (true range) and ``ATR`` (its ``n``-bar mean)
    """
    prev_close = REF(close, 1)
    # True range: the largest of the bar's range and its gaps against the previous close.
    true_range = MAX(MAX((high - low), ABS(prev_close - high)), ABS(prev_close - low))
    return pd.DataFrame({'MTR': true_range, 'ATR': MA(true_range, n)})
def HHV(number, n):
    """
    Highest value over a trailing window of ``n`` values.

    :param number: pandas Series
    :param n: window length
    :return: Series of rolling maxima (NaN until the window is full)
    """
    window = number.rolling(n)
    return window.max()
def LLV(number, n):
    """
    Lowest value over a trailing window of ``n`` values.

    :param number: pandas Series
    :param n: window length
    :return: Series of rolling minima (NaN until the window is full)
    """
    window = number.rolling(n)
    return window.min()
def SUM(number, n):
    """
    Sum over a trailing window of ``n`` values.

    :param number: pandas Series
    :param n: window length
    :return: Series of rolling sums (NaN until the window is full)
    """
    window = number.rolling(n)
    return window.sum()
def ABS(number):
    """
    Element-wise absolute value.

    :param number: scalar, array or Series
    :return: absolute value(s), in the form NumPy returns for the input
    """
    magnitude = np.absolute(number)
    return magnitude
def MAX(A, B):
    """
    Element-wise maximum of ``A`` and ``B`` (``np.maximum``, so NaN propagates).

    :param A: scalar, array or Series
    :param B: scalar, array or Series
    :return: the larger value at each position
    """
    larger = np.maximum(A, B)
    return larger
def MIN(A, B):
    """
    Element-wise minimum, selected through ``IF``: ``A`` where ``A < B``, otherwise ``B``.

    :param A: first operand
    :param B: second operand
    :return: whatever ``IF`` returns for the selection
    """
    return IF(A < B, A, B)
def IF(COND, V1, V2):
    """
    Element-wise selection: ``V1`` where ``COND`` is true, otherwise ``V2``.

    The values keep their positional order, but the returned Series is labelled
    in descending order (``len - 1`` down to ``0``).

    :param COND: boolean condition
    :param V1: value(s) used where the condition holds
    :param V2: value(s) used elsewhere
    :return: ``pd.Series`` with the selected values
    """
    chosen = np.where(COND, V1, V2)
    # Build the Series back-to-front, then flip it so the values return to their
    # original order while the labels end up descending.
    flipped = pd.Series(np.flip(chosen))
    return flipped.iloc[::-1]
def REF(DF, N):
    """
    Value of ``DF`` from ``N`` rows earlier (positional shift).

    The previous form ``DF - DF.diff(N)`` is algebraically the same but loses
    precision to floating-point cancellation and yields NaN for infinite values;
    ``shift`` returns the earlier values exactly.

    :param DF: pandas Series (or DataFrame)
    :param N: number of rows to look back
    :return: object of the same shape and index, with NaN in the first ``N`` rows
    """
    return DF.shift(N)
def STD(number, n):
    """
    Standard deviation over a trailing window of ``n`` values (pandas' rolling default).

    :param number: pandas Series
    :param n: window length
    :return: Series of rolling standard deviations (NaN until the window is full)
    """
    window = number.rolling(n)
    return window.std()
def MACD(close, f, s, m):
    """
    MACD lines built from two EMAs of the close.

    :param close: closing prices
    :param f: period of the fast EMA
    :param s: period of the slow EMA
    :param m: period of the EMA applied to the difference
    :return: DataFrame with ``DIFF`` (fast - slow), ``DEA`` (EMA of DIFF) and
        ``MACD`` (``(DIFF - DEA) * 2``), all rounded to 2 decimals
    """
    fast_line = EMA(close, f)
    slow_line = EMA(close, s)
    diff_line = fast_line - slow_line
    dea_line = EMA(diff_line, m)
    histogram = (diff_line - dea_line) * 2
    lines = {'DIFF': diff_line, 'DEA': dea_line, 'MACD': histogram}
    return pd.DataFrame({label: round(values, 2) for label, values in lines.items()})
def KDJ(close, high, low, n, m1, m2):
    """
    KDJ stochastic oscillator smoothed with ``SMA``.

    :param close: closing prices
    :param high: high prices
    :param low: low prices
    :param n: look-back window for the highest high / lowest low
    :param m1: smoothing period of K
    :param m2: smoothing period of D
    :return: DataFrame with ``KDJ_K``, ``KDJ_D`` and ``KDJ_J`` (``3K - 2D``), rounded to 2 decimals
    """
    floor = LLV(low, n)
    ceiling = HHV(high, n)
    # Position of the close inside the n-bar range, in percent.
    raw = (close - floor) / (ceiling - floor) * 100
    k_line = SMA(raw, m1, 1)
    d_line = SMA(k_line, m2, 1)
    j_line = 3 * k_line - 2 * d_line
    return pd.DataFrame({
        'KDJ_K': round(k_line, 2),
        'KDJ_D': round(d_line, 2),
        'KDJ_J': round(j_line, 2),
    })
def OSC(close, n, m):
    """
    Price oscillator: distance of the close from its moving average.

    :param close: closing prices
    :param n: window of the moving average
    :param m: period of the EMA applied to the oscillator
    :return: DataFrame with ``OSC`` (``(close - MA) * 100``) and ``MAOSC`` (its EMA)
    """
    oscillator = (close - MA(close, n)) * 100
    smoothed = EMA(oscillator, m)
    return pd.DataFrame({'OSC': oscillator, 'MAOSC': smoothed})
def BBI(close, N1, N2, N3, N4):
    """
    Bull and bear index: the mean of four moving averages of the close.

    :param close: closing prices
    :param N1: window of the first moving average
    :param N2: window of the second moving average
    :param N3: window of the third moving average
    :param N4: window of the fourth moving average
    :return: DataFrame with a single ``BBI`` column rounded to 2 decimals
    """
    ma1 = MA(close, N1)
    ma2 = MA(close, N2)
    ma3 = MA(close, N3)
    ma4 = MA(close, N4)
    blended = (ma1 + ma2 + ma3 + ma4) / 4
    return pd.DataFrame({'BBI': round(blended, 2)})
def BBIBOLL(close, n, m, n1=3, n2=6, n3=12, n4=24):
    """
    Bollinger-style bands drawn around the BBI line.

    :param close: closing prices
    :param n: window of the standard deviation
    :param m: band width in standard deviations
    :param n1: first BBI window
    :param n2: second BBI window
    :param n3: third BBI window
    :param n4: fourth BBI window
    :return: DataFrame with ``BBIBOLL`` (the BBI line), ``UPER`` and ``DOWN``, rounded to 2 decimals
    """
    centre = BBI(close, n1, n2, n3, n4)['BBI']
    half_width = m * STD(centre, n)
    upper = centre + half_width
    lower = centre - half_width
    return pd.DataFrame({
        'BBIBOLL': round(centre, 2),
        'UPER': round(upper, 2),
        'DOWN': round(lower, 2),
    })
def PBX(close, n1, n2, n3, n4, n5, n6):
    """
    Waterfall lines: six blended averages of the close.

    Each line is ``(EMA(close, n) + MA(close, 2n) + MA(close, 4n)) / 3`` for one of the six periods.

    :param close: closing prices
    :param n1: period of ``PBX1``
    :param n2: period of ``PBX2``
    :param n3: period of ``PBX3``
    :param n4: period of ``PBX4``
    :param n5: period of ``PBX5``
    :param n6: period of ``PBX6``
    :return: DataFrame with columns ``PBX1`` .. ``PBX6`` rounded to 2 decimals
    """
    def _line(period):
        return (EMA(close, period) + MA(close, 2 * period) + MA(close, 4 * period)) / 3

    periods = (n1, n2, n3, n4, n5, n6)
    return pd.DataFrame(
        {f'PBX{rank}': round(_line(period), 2) for rank, period in enumerate(periods, start=1)}
    )
def BOLL(close, N):
    """
    Bollinger bands: ``N``-bar moving average plus/minus two rolling standard deviations.

    :param close: closing prices
    :param N: window length
    :return: DataFrame with ``BOLL`` (middle), ``UB`` (upper) and ``LB`` (lower), rounded to 2 decimals
    """
    middle = MA(close, N)
    half_width = 2 * STD(close, N)
    bands = {'BOLL': middle, 'UB': middle + half_width, 'LB': middle - half_width}
    return pd.DataFrame({label: round(values, 2) for label, values in bands.items()})
def ROC(close, n, m):
    """
    Rate of change of the close against its value ``n`` bars earlier, in percent.

    :param close: closing prices
    :param n: look-back distance
    :param m: window of the moving average applied to the rate
    :return: DataFrame with ``ROC`` and ``MAROC``, rounded to 2 decimals
    """
    earlier = REF(close, n)
    rate = 100 * (close - earlier) / earlier
    rate_ma = MA(rate, m)
    return pd.DataFrame({'ROC': round(rate, 2), 'MAROC': round(rate_ma, 2)})
def MTM(close, n, m):
    """
    Momentum: the close minus its value ``n`` bars earlier.

    :param close: closing prices
    :param n: look-back distance
    :param m: window of the moving average applied to the momentum
    :return: DataFrame with ``MTM`` and ``MTMMA``, rounded to 2 decimals
    """
    momentum = close - REF(close, n)
    momentum_ma = MA(momentum, m)
    return pd.DataFrame({'MTM': round(momentum, 2), 'MTMMA': round(momentum_ma, 2)})
def MFI(close, high, low, vol, n):
    """
    Money flow index.

    :param close: closing prices
    :param high: high prices
    :param low: low prices
    :param vol: volumes
    :param n: summation window
    :return: DataFrame with a single ``MFI`` column rounded to 2 decimals
    """
    typical = (close + high + low) / 3
    earlier = REF(typical, 1)
    # Volume-weighted typical price, split by whether the typical price rose or fell.
    inflow = SUM(IF(typical > earlier, typical * vol, 0), n)
    outflow = SUM(IF(typical < earlier, typical * vol, 0), n)
    ratio = inflow / outflow
    index_line = 100 - (100 / (1 + ratio))
    return pd.DataFrame({'MFI': round(index_line, 2)})
def SKDJ(close, high, low, N, M):
    """
    Slow stochastic: the raw stochastic value is EMA-smoothed twice for K, and D is a moving average of K.

    :param close: closing prices
    :param high: high prices
    :param low: low prices
    :param N: look-back window for the highest high / lowest low
    :param M: smoothing period
    :return: DataFrame with ``SKDJ_K`` and ``SKDJ_D``, rounded to 2 decimals
    """
    floor = LLV(low, N)
    ceiling = HHV(high, N)
    smoothed_raw = EMA((close - floor) / (ceiling - floor) * 100, M)
    k_line = EMA(smoothed_raw, M)
    d_line = MA(k_line, M)
    return pd.DataFrame({'SKDJ_K': round(k_line, 2), 'SKDJ_D': round(d_line, 2)})
def WR(close, high, low, N, N1):
    """
    Williams %R computed for two look-back windows.

    :param close: closing prices
    :param high: high prices
    :param low: low prices
    :param N: window of ``WR1``
    :param N1: window of ``WR2``
    :return: DataFrame with ``WR1`` and ``WR2``, rounded to 2 decimals
    """
    def _line(window):
        ceiling = HHV(high, window)
        return round(100 * (ceiling - close) / (ceiling - LLV(low, window)), 2)

    return pd.DataFrame({'WR1': round(_line(N), 2), 'WR2': round(_line(N1), 2)})
def BIAS(DF, N1, N2, N3):
    """
    Bias ratio: percentage distance of the price from its moving average, for three windows.

    :param DF: price Series (used as the close)
    :param N1: window of ``BIAS1``
    :param N2: window of ``BIAS2``
    :param N3: window of ``BIAS3``
    :return: DataFrame with ``BIAS1``, ``BIAS2`` and ``BIAS3`` (not rounded)
    """
    def _line(window):
        average = MA(DF, window)
        return (DF - average) / average * 100

    return pd.DataFrame(
        {f'BIAS{rank}': _line(window) for rank, window in enumerate((N1, N2, N3), start=1)}
    )
def RSI(c, N1, N2, N3):
    """
    Relative strength index for three periods.

    Reference formula: ``RSI1: SMA(MAX(CLOSE-LC,0),N1,1) / SMA(ABS(CLOSE-LC),N1,1) * 100``.
    The denominator is scaled by 100 and rounded to 3 decimals before dividing, and
    the quotient is then scaled by 10000.

    :param c: closing prices
    :param N1: period of ``RSI1``
    :param N2: period of ``RSI2``
    :param N3: period of ``RSI3``
    :return: DataFrame with ``RSI1``, ``RSI2`` and ``RSI3``, rounded to 2 decimals
    """
    change = c - REF(c, 1)

    def _line(period):
        gains = SMA(MAX(change, 0), period)
        moves = round(SMA(ABS(change), period) * 100, 3)
        return round((gains / moves) * 10000, 2)

    return pd.DataFrame({'RSI1': _line(N1), 'RSI2': _line(N2), 'RSI3': _line(N3)})
def ADTM(DF, N, M):
    """
    Dynamic buy/sell momentum indicator.

    :param DF: frame with ``high``, ``low`` and ``open`` columns
    :param N: summation window
    :param M: window of the moving average applied to the indicator
    :return: DataFrame with ``ADTM`` and ``MAADTM``
    """
    high, low, open_ = DF['high'], DF['low'], DF['open']
    prev_open = REF(open_, 1)
    open_change = open_ - prev_open
    # Upward pressure counts only when the open rose, downward only when it fell.
    up_move = IF(open_ <= prev_open, 0, MAX((high - open_), open_change))
    down_move = IF(open_ >= prev_open, 0, MAX((open_ - low), open_change))
    up_sum = SUM(up_move, N)
    down_sum = SUM(down_move, N)
    balance = IF(up_sum > down_sum, (up_sum - down_sum) / up_sum,
                 IF(up_sum == down_sum, 0, (up_sum - down_sum) / down_sum))
    return pd.DataFrame({'ADTM': balance, 'MAADTM': MA(balance, M)})
def DDI(DF, N, N1, M, M1):
    """
    Directional deviation index.

    :param DF: frame with ``high`` and ``low`` columns
    :param N: summation window
    :param N1: period passed to ``SMA`` for ``ADDI``
    :param M: weight passed to ``SMA`` for ``ADDI``
    :param M1: window of the moving average producing ``AD``
    :return: DataFrame with ``DDI``, ``ADDI`` and ``AD``
    """
    high, low = DF['high'], DF['low']
    prev_high, prev_low = REF(high, 1), REF(low, 1)
    range_sum = high + low
    prev_range_sum = prev_high + prev_low
    # Larger of the absolute high and low changes.
    swing = MAX(ABS(high - prev_high), ABS(low - prev_low))
    up_part = IF(range_sum <= prev_range_sum, 0, swing)
    down_part = IF(range_sum >= prev_range_sum, 0, swing)
    up_sum = SUM(up_part, N)
    down_sum = SUM(down_part, N)
    up_share = up_sum / (up_sum + down_sum)
    down_share = down_sum / (down_sum + up_sum)
    ddi = up_share - down_share
    addi = SMA(ddi, N1, M)
    return pd.DataFrame({'DDI': ddi, 'ADDI': addi, 'AD': MA(addi, M1)})
# States of the ZIG scanner.
ZIG_STATE_START = 0  # no swing confirmed yet
ZIG_STATE_RISE = 1  # tracking a rising leg
ZIG_STATE_FALL = 2  # tracking a falling leg


def ZIG(d, k, n):
    """
    Zig-zag indicator: a new turning point is recorded once the price reverses by more than ``n`` percent.

    The threshold actually applied is ``round(n / 100, 2)``, so fractional percentages
    are rounded to whole percents.
    Debug lines are printed for the detected transitions.

    :param d: trade dates, indexable like ``k`` (only used in the printed messages)
    :param k: prices, indexable by ``0 .. len(k) - 1``
    :param n: reversal threshold in percent
    :return: tuple ``(zig, peers)``: ``zig`` is a ``pd.Series`` linearly interpolated
        between turning points, ``peers`` is the list of turning-point positions.
        Inputs shorter than two bars return the prices unchanged with ``peers``
        equal to ``[0]`` (or ``[]`` when empty).
    """
    x = round(n / 100, 2)
    size = len(k)
    z = np.zeros(size)
    if size < 2:
        # The scanner needs at least two bars; previously this ran past the end of
        # ``k`` and raised.
        for i in range(size):
            z[i] = k[i]
        return pd.Series(z), ([0] if size else [])
    peer_i = 0
    candidate_i = None
    scan_i = 0
    peers = [0]
    state = ZIG_STATE_START
    while True:
        scan_i += 1
        if scan_i == size - 1:
            # Reached the last bar: close the final leg.
            if candidate_i is None:
                peer_i = scan_i
                peers.append(peer_i)
            else:
                if state == ZIG_STATE_RISE:
                    if k[scan_i] >= k[candidate_i]:
                        # The last bar extends the rising leg.
                        print(d[scan_i], "1 --->>>", d[candidate_i])
                        peer_i = scan_i
                        peers.append(peer_i)
                    else:
                        # The candidate stays a turning point, followed by the last bar.
                        peer_i = candidate_i
                        peers.append(peer_i)
                        peer_i = scan_i
                        peers.append(peer_i)
                elif state == ZIG_STATE_FALL:
                    if k[scan_i] <= k[candidate_i]:
                        # The last bar extends the falling leg.
                        print(d[scan_i], "2 --->>>", d[candidate_i])
                        peer_i = scan_i
                        peers.append(peer_i)
                    else:
                        peer_i = candidate_i
                        peers.append(peer_i)
                        peer_i = scan_i
                        peers.append(peer_i)
            break
        if state == ZIG_STATE_START:
            # Wait for the first move of more than x away from the starting bar.
            if k[scan_i] >= k[peer_i] * (1 + x):
                print(d[scan_i], "3 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_RISE
            elif k[scan_i] <= k[peer_i] * (1 - x):
                print(d[scan_i], "4 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_FALL
        elif state == ZIG_STATE_RISE:
            if k[scan_i] >= k[candidate_i]:
                # New high: move the candidate peak forward.
                candidate_i = scan_i
            elif k[scan_i] <= k[candidate_i] * (1 - x):
                # Dropped by more than x from the candidate: the peak is confirmed.
                print(d[scan_i], "5 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_FALL
                candidate_i = scan_i
        elif state == ZIG_STATE_FALL:
            if k[scan_i] <= k[candidate_i]:
                # New low: move the candidate trough forward.
                print(d[scan_i], "6 --->>>", d[candidate_i])
                candidate_i = scan_i
            elif k[scan_i] >= k[candidate_i] * (1 + x):
                # Rose by more than x from the candidate: the trough is confirmed.
                print(d[scan_i], "7 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_RISE
                candidate_i = scan_i
    # Connect consecutive turning points with straight lines.
    for i in range(len(peers) - 1):
        peer_start_i = peers[i]
        peer_end_i = peers[i + 1]
        start_value = k[peer_start_i]
        end_value = k[peer_end_i]
        a = (end_value - start_value) / (peer_end_i - peer_start_i)  # slope
        for j in range(peer_end_i - peer_start_i + 1):
            z[j + peer_start_i] = start_value + a * j
    return pd.Series(z), peers
def TROUGHBARS(z, p, m):
    """
    Mark the distance between zig turning points that lie ``m`` swings apart.

    Depending on whether the zig line starts by falling or rising, the odd or the
    even turning points are used as anchors. For an anchor ``p[i]`` and the turning
    point ``p[i + 2 * m]``, the bar just before the latter receives the number of
    bars in between and the latter itself receives 1.

    :param z: zig line (needs more than 3 values, otherwise the result is all zeros)
    :param p: positions of the zig turning points
    :param m: how many troughs back to measure
    :return: ``pd.Series`` of the same length as ``z``
    """
    marks = np.zeros(len(z))
    if len(z) > 3:
        # The two starting directions differ only in which parity of turning points
        # serves as anchors.
        parities = []
        if z[0] > z[1]:
            parities.append(1)
        if z[0] < z[1]:
            parities.append(0)
        offset = m * 2
        for parity in parities:
            for i, anchor in enumerate(p):
                j = i + offset
                if 0 < i and len(p) > j and i % 2 == parity:
                    marks[p[j] - 1] = p[j] - anchor - 1
                    marks[p[j]] = 1
    return pd.Series(marks)
def CROSS(a, b):
    """
    Crossing signal between two equally long sequences.

    Position ``i`` is compared with position ``i + 1``:
    ``1`` when ``a`` is above ``b`` at ``i``, not above it at ``i + 1``, and ``a[i + 1] < a[i]``;
    ``-1`` for the mirrored case; ``0`` otherwise. The last position is always ``0``.

    :param a: first sequence, indexable by ``0 .. len - 1``
    :param b: second sequence, same length as ``a``
    :return: ``pd.Series`` of ``1`` / ``-1`` / ``0`` flags
    """
    assert len(a) == len(b), '穿越信号输入维度不相等'
    assert len(a) > 1, '穿越信号长度至少为2'
    size = len(a)
    flags = np.zeros(size)
    for here in range(size - 1):
        there = here + 1
        if a[there] <= b[there] and a[here] > b[here] and a[there] < a[here]:
            flags[here] = 1
        elif a[there] >= b[there] and a[here] < b[here] and a[there] > a[here]:
            flags[here] = -1
    return pd.Series(flags)
def _calc_slope(x):
    """
    Slope of the least-squares line through ``x`` plotted against ``0 .. len(x) - 1``.

    :param x: sequence of values
    :return: leading coefficient of the degree-1 ``np.polyfit``
    """
    positions = range(len(x))
    slope = np.polyfit(positions, x, 1)[0]
    return slope
def rolling_window(a, window):
    """
    Strided view of every run of ``window`` consecutive values of ``a``.

    :param a: pandas Series (``a.values`` provides the strides)
    :param window: length of each run
    :return: array of shape ``(len(a) - window + 1, window)`` built with ``as_strided``
    """
    base_strides = a.values.strides
    run_count = a.shape[-1] - window + 1
    view_shape = a.shape[:-1] + (run_count, window)
    # Stepping to the next run and to the next element inside a run both advance
    # by one item, so the last stride is repeated.
    view_strides = base_strides + (base_strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=view_shape, strides=view_strides)
def SLOPE(series, n):
    """
    Rolling linear-regression slope over ``n`` values (``SLOPE(X, N)``).

    Reference: https://blog.csdn.net/luhouxiang/article/details/113816062

    :param series: pandas Series
    :param n: regression window
    :return: NumPy array as long as ``series``; the leading positions without a full window are NaN
    """
    slopes = np.array([_calc_slope(run) for run in rolling_window(series, n)])
    missing = len(series) - len(slopes)
    return np.pad(slopes, (missing, 0), 'constant', constant_values=(np.nan, np.nan))
def GOLD_MACD(df_data: pd.DataFrame):
    """
    "Gold MACD" indicator built on the bar-to-bar percentage change of the 30-period EMA.

    The input rows are reversed and re-indexed before the calculation.

    :param df_data: frame with ``close``, ``trade_date`` and ``ts_code`` columns
    :return: DataFrame with ``code``, ``date``, ``MACD``, ``DIF``, ``DEA`` and the boolean
        columns ``buy1`` (DIF above its previous value) and ``buy2`` (DIF below it)
    """
    df = df_data[::-1].reset_index(drop=True)
    ema30 = EMA(df["close"], 30)
    prev_ema30 = REF(ema30, 1)
    macd_line = (ema30 - prev_ema30) / prev_ema30 * 100
    dif_line = EMA(SUM(macd_line, 2), 5)
    prev_dif = REF(dif_line, 1)
    rising = dif_line > prev_dif
    falling = dif_line < prev_dif
    dea_line = MA(dif_line, 5)
    return pd.DataFrame({
        'code': df['ts_code'],
        'date': df["trade_date"],
        'MACD': macd_line,
        'DIF': dif_line,
        'DEA': dea_line,
        'buy1': rising,
        "buy2": falling,
    })
def DJCPX(df_data: pd.DataFrame):
    """
    "Top trading line" indicator.

    The input rows are reversed and re-indexed before the calculation.

    :param df_data: frame with ``close``, ``high``, ``low``, ``trade_date`` and ``ts_code`` columns
    :return: DataFrame with ``code``, ``date``, ``F`` (``low * 0.98`` where ``CROSS`` flags 1, else 0),
        ``黄金线`` (trading line where the centre of gravity is at or above it, else 0) and
        ``空仓线`` (trading line elsewhere, else 0)
    """
    df = df_data[::-1].reset_index(drop=True)
    rows = df.shape[0]
    k = df["close"]
    high, low = df["high"], df["low"]

    top20, bottom20 = HHV(high, 20), LLV(low, 20)
    VAR_200 = round((100 - ((90 * (top20 - k)) / (top20 - bottom20))), 2)
    top5, bottom5 = HHV(high, 5), LLV(low, 5)
    VAR_300 = round((100 - MA(((100 * (top5 - k)) / (top5 - bottom5)), 34)), 2)
    VAR_300_MA_5 = MA(VAR_300, 5)

    # F: IF(CROSS(VAR200, MA(VAR300, 5)), LOW * 0.98, DRAWNULL)
    F = np.zeros(rows)
    crossing = CROSS(VAR_200, VAR_300_MA_5)
    for i in range(rows):
        if crossing[i] == 1:
            F[i] = round(low[i] * 0.98, 2)

    # Centre of gravity := (C + 0.618*REF(C,1) + 0.382*REF(C,1) + 0.236*REF(C,3) + 0.146*REF(C,4)) / 2.382
    ZX = round((k + (0.618 * REF(k, 1)) + (0.382 * REF(k, 1)) + (0.236 * REF(k, 3)) + (
            0.146 * REF(k, 4))) / 2.382, 2)
    # Trading line: EMA((SLOPE(C, 22) * 20) + C, 55)
    CPX = round(EMA(((SLOPE(k, 22) * 20) + k), 55), 2)

    # Gold line: trading line where centre >= trading line.
    # Flat line: trading line where centre < trading line (also where the comparison is NaN).
    above = (ZX >= CPX).to_numpy()
    line = CPX.to_numpy()
    HJX = np.where(above, line, 0.0)
    KCX = np.where(above, 0.0, line)
    return pd.DataFrame(
        {'code': df['ts_code'], 'date': df["trade_date"], 'F': F, '黄金线': HJX, '空仓线': KCX})
def CCI(DF, n: int = 14):
    """
    Commodity channel index.

    :param DF: frame with ``low``, ``high`` and ``close`` columns
    :param n: rolling window, defaults to 14
    :return: Series ``(TP - mean) / (0.015 * mean absolute deviation)`` rounded to 2 decimals,
        where ``TP`` is the average of low, high and close
    """
    typical = (DF['low'] + DF['high'] + DF['close']) / 3
    windows = typical.rolling(window=n)
    centre = windows.mean()

    def _mean_abs_dev(chunk):
        return abs(chunk - chunk.mean()).mean()

    deviation = typical.rolling(window=n).apply(_mean_abs_dev, raw=False)
    return round((typical - centre) / (0.015 * deviation), 2)
def bullish(DF, N):
    """
    Flag windows of ``N`` values that are monotonically increasing.

    :param DF: pandas Series
    :param N: window length
    :return: Series from the rolling ``apply`` of ``is_monotonic_increasing``
        (NaN until the window is full)
    """
    def _rising(chunk):
        return chunk.is_monotonic_increasing

    return DF.rolling(N).apply(_rising)
def bearish(DF, N):
    """
    Flag windows of ``N`` values that are monotonically decreasing.

    :param DF: pandas Series
    :param N: window length
    :return: Series from the rolling ``apply`` of ``is_monotonic_decreasing``
        (NaN until the window is full)
    """
    def _falling(chunk):
        return chunk.is_monotonic_decreasing

    return DF.rolling(N).apply(_falling)
def OBV(c, v, M):
    """
    On-balance volume and its moving average.

    :param c: closing prices
    :param v: volumes, same length as ``c``
    :param M: window of the moving average
    :return: DataFrame with ``OBV`` and ``MAOBV``
    """
    # Sign of each bar-to-bar price change: 1 up, -1 down, 0 flat.
    direction = np.sign(np.diff(c))
    # The first bar has no predecessor and counts as +1.
    signed = np.hstack([[1], direction])
    # Running total of the signed volume.
    running = pd.Series(np.cumsum(v * signed))
    return pd.DataFrame({'OBV': running, 'MAOBV': MA(running, M)})
def OBV_PLUS(DF, M):
    """
    Upgraded OBV strategy -- TODO: not finished yet.

    Planned improvements:
    1. give more weight to the volume of days with a large price move, to emphasise trends;
    2. add only a proportion of the day's volume to the OBV instead of all of it.

    At present the function only accumulates a signed volume total in a local
    variable and returns ``None``; ``M`` is not used.

    :param DF: frame with ``close`` and ``vol`` columns
    :param M: reserved, currently unused
    :return: ``None``
    """
    volumes = DF['vol']
    prev_close = REF(DF['close'], 1)
    var_total = 0
    for index, row in DF.iterrows():
        earlier = prev_close[index]
        if np.isnan(earlier):
            # No previous close: count the volume as is.
            var_total += volumes[index]
            continue
        if row["close"] > earlier:
            signed = volumes[index] * 1
        elif row["close"] == earlier:
            signed = 0
        else:
            signed = volumes[index] * -1
        var_total += signed
def ASI(OPEN, CLOSE, HIGH, LOW, M1=26, M2=10):
    """
    Accumulation swing index.

    :param OPEN: open prices
    :param CLOSE: closing prices
    :param HIGH: high prices
    :param LOW: low prices
    :param M1: summation window of the swing index, defaults to 26
    :param M2: window of the moving average, defaults to 10
    :return: dict with ``ASI`` (rolling sum of the swing index) and ``ASIT`` (its moving average)
    """
    prev_close = REF(CLOSE, 1)
    prev_open = REF(OPEN, 1)
    high_gap = ABS(HIGH - prev_close)
    low_gap = ABS(LOW - prev_close)
    span = ABS(HIGH - REF(LOW, 1))
    body = ABS(prev_close - prev_open)
    # The scaling term depends on which of the three distances dominates.
    scale = IF((high_gap > low_gap) & (high_gap > span), high_gap + low_gap / 2 + body / 4,
               IF((low_gap > span) & (low_gap > high_gap), low_gap + high_gap / 2 + body / 4, span + body / 4))
    move = (CLOSE - prev_close + (CLOSE - OPEN) / 2 + prev_close - prev_open)
    swing = 16 * move / scale * MAX(high_gap, low_gap)
    accumulated = SUM(swing, M1)
    return {'ASI': accumulated, 'ASIT': MA(accumulated, M2)}
def DMI(CLOSE, HIGH, LOW, M1=14, M2=6):
    """
    Directional movement index.

    The original author notes that the results match THS (TongHuaShun) and TDX (TongDaXin) exactly.

    :param CLOSE: closing prices
    :param HIGH: high prices
    :param LOW: low prices
    :param M1: summation window, defaults to 14
    :param M2: window of the ADX average and the ADXR look-back, defaults to 6
    :return: dict with ``PDI``, ``MDI``, ``ADX`` and ``ADXR``; NaN is replaced by 0 and
        values are rounded to 2 decimals
    """
    prev_close = REF(CLOSE, 1)
    bar_range = MAX(MAX(HIGH - LOW, ABS(HIGH - prev_close)), ABS(LOW - prev_close))
    range_sum = SUM(bar_range, M1)
    up_move = HIGH - REF(HIGH, 1)
    down_move = REF(LOW, 1) - LOW
    plus_dm = SUM(IF((up_move > 0) & (up_move > down_move), up_move, 0), M1)
    minus_dm = SUM(IF((down_move > 0) & (down_move > up_move), down_move, 0), M1)
    plus_di = (plus_dm * 100) / range_sum
    minus_di = (minus_dm * 100) / range_sum
    adx = MA(ABS(minus_di - plus_di) / (plus_di + minus_di) * 100, M2)
    adxr = (adx + REF(adx, M2)) / 2
    lines = {'PDI': plus_di, 'MDI': minus_di, 'ADX': adx, 'ADXR': adxr}
    return {label: round(values.fillna(0), 2) for label, values in lines.items()}
def RM_KDJ(C, H, L, N, M1, M2):
    """
    KDJ variant smoothed with the recursive ``RM_SMA`` instead of ``SMA``.

    :param C: closing prices
    :param H: high prices
    :param L: low prices
    :param N: look-back window for the highest high / lowest low
    :param M1: smoothing period of K
    :param M2: smoothing period of D
    :return: DataFrame with ``KDJ_K``, ``KDJ_D`` and ``KDJ_J`` (``3K - 2D``), rounded to 2 decimals
    """
    floor = LLV(L, N)
    ceiling = HHV(H, N)
    raw = (C - floor) / (ceiling - floor) * 100
    k_line = RM_SMA(raw, M1, 1)
    d_line = RM_SMA(k_line, M2, 1)
    j_line = 3 * k_line - 2 * d_line
    return pd.DataFrame({
        'KDJ_K': round(k_line, 2),
        'KDJ_D': round(d_line, 2),
        'KDJ_J': round(j_line, 2),
    })
def INTPART(number):
    """
    Integer part of each value, with missing values counted as 0.

    :param number: pandas Series
    :return: Series cast to ``int``
    """
    filled = number.fillna(0)
    return filled.astype(int)
def JXNH(CLOSE, OPEN, VOL):
    """
    Boolean "JXNH" pattern signal.

    True where the bar opens at or below the 5/10/30-bar moving averages, closes at or
    above the 5-bar average and the truncated 28-bar volume-weighted price, while both
    ``EMA5 - EMA10`` and its 9-period EMA are rising.

    :param CLOSE: closing prices
    :param OPEN: open prices
    :param VOL: volumes
    :return: DataFrame with a single boolean ``JXNH`` column
    """
    ma5 = MA(CLOSE, 5)
    ma10 = MA(CLOSE, 10)
    ma30 = MA(CLOSE, 30)
    # 28-bar volume-weighted price, truncated to 2 decimals.
    weighted = SUM(CLOSE * VOL * 100, 28) / SUM(VOL * 100, 28)
    weighted_cut = INTPART(weighted * 100) / 100
    spread = EMA(CLOSE, 5) - EMA(CLOSE, 10)
    spread_ema = EMA(spread, 9)
    spread_ema_change = spread_ema - REF(spread_ema, 1)
    spread_change = spread - REF(spread, 1)
    signal = (OPEN <= ma5) & \
             (OPEN <= ma10) & \
             (OPEN <= ma30) & \
             (CLOSE >= ma5) & \
             (CLOSE >= weighted_cut) & (spread_ema_change > 0) & (spread_change > 0)
    return pd.DataFrame({'JXNH': signal})

176
utils/tdxUtil.py Normal file
View File

@ -0,0 +1,176 @@
import struct
from pathlib import Path
import pandas as pd
"""
通达信操作工具类
"""
class TdxUtil:
    """Read and write the local data files of a TongDaXin (TDX) installation."""

    SECURITY_EXCHANGE = ["sz", "sh"]
    # Security type -> [price coefficient, volume coefficient] applied to the raw
    # integers of a .day record.
    SECURITY_COEFFICIENT = {"SH_A_STOCK": [0.01, 0.01], "SH_B_STOCK": [0.001, 0.01], "SH_INDEX": [0.01, 1.0],
                            "SH_FUND": [0.001, 1.0], "SH_BOND": [0.001, 1.0], "SZ_A_STOCK": [0.01, 0.01],
                            "SZ_B_STOCK": [0.01, 0.01], "SZ_INDEX": [0.01, 1.0], "SZ_FUND": [0.001, 0.01],
                            "SZ_BOND": [0.001, 0.01]}
    SECURITY_TYPE = ["SH_A_STOCK", "SH_B_STOCK", "SH_INDEX", "SH_FUND", "SH_BOND", "SZ_A_STOCK", "SZ_B_STOCK",
                     "SZ_INDEX", "SZ_FUND", "SZ_BOND"]

    # Code prefixes that receive the market digit '1' / '0' in the watch-list file.
    _ZXG_PREFIXES_1 = ('50', '51', '60', '688', '73', '90', '110', '113', '132', '204', '78')
    _ZXG_PREFIXES_0 = ('00', '13', '18', '15', '16', '20', '30', '39', '115', '1318')

    # First two code digits -> security type, per exchange.
    _SZ_TYPES = {"00": "SZ_A_STOCK", "30": "SZ_A_STOCK", "20": "SZ_B_STOCK", "39": "SZ_INDEX",
                 "15": "SZ_FUND", "16": "SZ_FUND",
                 "10": "SZ_BOND", "11": "SZ_BOND", "12": "SZ_BOND", "13": "SZ_BOND", "14": "SZ_BOND"}
    _SH_TYPES = {"60": "SH_A_STOCK", "68": "SH_A_STOCK",  # 688xxx: STAR market
                 "90": "SH_B_STOCK",
                 "00": "SH_INDEX", "88": "SH_INDEX", "99": "SH_INDEX",
                 "50": "SH_FUND", "51": "SH_FUND",
                 "01": "SH_BOND", "10": "SH_BOND", "11": "SH_BOND", "12": "SH_BOND", "13": "SH_BOND",
                 "14": "SH_BOND", "20": "SH_BOND"}
    # Without an exchange only main-board A shares are recognised (TODO: ChiNext excluded).
    _DEFAULT_TYPES = {"00": "SZ_A_STOCK", "60": "SH_A_STOCK"}

    def __init__(self, tdx_path: str):
        """
        :param tdx_path: root directory of the TDX installation
        """
        self.f_name = None
        self.path = tdx_path
        # Watch-list (self-selected stocks) file.
        self.zxg_path = tdx_path + "\\T0002\\blocknew\\ZXG.blk"
        # Shenzhen daily bars.
        self.lday_sz_path = tdx_path + "\\vipdoc\\sz\\lday\\"
        # Shanghai daily bars.
        self.lday_sh_path = tdx_path + "\\vipdoc\\sh\\lday\\"

    def set_zxg_file(self, cont):
        """
        Overwrite the TDX watch-list file with the given stock codes.

        Each code is written on its own line as the market digit followed by the bare
        code; a suffix after the first ``.`` (e.g. ``.SH``) is dropped. Codes matching
        neither prefix table are written without a market digit, and codes of at most
        one character are skipped.

        :param cont: iterable of stock codes (e.g. a Series of ``600000.SH`` strings)
        :return: None
        """
        with open(self.zxg_path, "w") as file:
            for item in cont:
                code = str(item).split(".")[0]
                if len(code) <= 1:
                    continue
                entry = code
                if code.startswith(self._ZXG_PREFIXES_1):
                    entry = '1' + code
                # Checked second on purpose: as in the original logic, this table wins
                # when a code matches both.
                if code.startswith(self._ZXG_PREFIXES_0):
                    entry = '0' + code
                file.write(entry + '\n')

    def get_zxg_file(self):
        """
        Read the TDX watch-list file.

        :return: the file content as one string
        """
        with open(self.zxg_path, 'r') as f:
            return f.read()

    def get_tdx_stock_data(self, code=None, freq=None, market=None):
        """
        Fetch K-line data (daily, or minute bars once implemented).

        :param code: stock code; ``None`` returns the list of all A-share codes instead
        :param freq: minute frequency (1 for 1-minute, 5 for 5-minute ...); ``None`` means daily bars
        :param market: exchange: ``sh`` Shanghai, ``sz`` Shenzhen, ``bj`` Beijing
        :return: list of codes, or a DataFrame of bars
        """
        if code is None:
            return self.get_all_stock()
        if freq is None:
            return self.get_day_k(code, market)
        return self.get_freq_k(code, market, freq)

    def get_all_stock(self):
        """
        List the A-share codes that have a daily-bar file under the TDX installation.

        :return: list of bare stock codes
        """
        wanted = ("SZ_A_STOCK", "SH_A_STOCK")
        codes = []
        for exchange in self.SECURITY_EXCHANGE:
            # Resolve the directory under the configured installation root rather
            # than the root of the current drive.
            lday_dir = Path(f'{self.path}\\vipdoc\\{exchange}\\lday\\')
            for file in lday_dir.iterdir():
                if not file.is_file():
                    continue
                tdx_code = file.name.split(".")[0]
                # File names look like ``sz000001.day``: exchange prefix + code.
                file_market, code = tdx_code[:2], tdx_code[2:]
                if self.get_security_type(code=code, market=file_market) in wanted:
                    codes.append(code)
        return codes

    def get_day_k(self, code, market):
        """
        Read the daily bars of one security from its ``.day`` file.

        :param code: bare security code
        :param market: ``sz`` / ``sh``; ``None`` falls back to main-board A-share detection
        :return: DataFrame with ``trade_date``, ``open``, ``high``, ``low``, ``close``,
            ``amount`` and ``volume``
        :raises NotImplementedError: when the security type cannot be determined
        :raises FileNotFoundError: when the ``.day`` file does not exist
        """
        # ``market`` must be passed by keyword: the second positional parameter of
        # get_security_type is ``name``.
        security_type = self.get_security_type(code, market=market)
        if security_type not in self.SECURITY_TYPE:
            print("Unknown security type !\n")
            raise NotImplementedError
        if security_type.startswith("SZ"):
            self.f_name = f"{self.lday_sz_path}sz{code}.day"
        else:
            self.f_name = f"{self.lday_sh_path}sh{code}.day"
        print(self.f_name)
        file_path = Path(self.f_name)
        if not file_path.is_file():
            raise FileNotFoundError(f"{self.f_name} 不是一个文件!")
        price_k, volume_k = self.SECURITY_COEFFICIENT[security_type]
        content = file_path.read_bytes()
        # One record: date, open, high, low, close (ints), amount (float), volume, one unused int.
        record_struct = struct.Struct('<IIIIIfII')
        print(len(content), record_struct.size)
        # Ignore a trailing partial record instead of failing on it.
        usable = len(content) - len(content) % record_struct.size
        rows = [
            [
                str(rec[0]),
                rec[1] * price_k,
                rec[2] * price_k,
                rec[3] * price_k,
                rec[4] * price_k,
                rec[5],
                rec[6] * volume_k,
            ]
            for rec in record_struct.iter_unpack(memoryview(content)[:usable])
        ]
        return pd.DataFrame(rows, columns=['trade_date', 'open', 'high', 'low', 'close', 'amount', 'volume'])

    def get_freq_k(self, code, market, freq):
        """Minute bars -- not implemented yet; returns ``None``."""
        pass

    def get_security_type(self, code, name=None, market=None):
        """
        Classify a security from the first two digits of its code and, if given, its exchange.

        :param code: bare security code
        :param name: optional security name; names containing ``ST`` or ``*`` are rejected
        :param market: ``sz`` / ``sh`` (case-insensitive); anything else uses the default table
        :return: one of ``SECURITY_TYPE``, or ``None`` when the code is not recognised
        """
        if name is not None and ('ST' in name or '*' in name):
            return None
        exchange = str(market).lower()
        code_head = code[:2]
        if exchange == self.SECURITY_EXCHANGE[0]:
            table = self._SZ_TYPES
        elif exchange == self.SECURITY_EXCHANGE[1]:
            table = self._SH_TYPES
        else:
            table = self._DEFAULT_TYPES
        return table.get(code_head)
if __name__ == '__main__':
    # Manual smoke run: print the daily bars of one stock from a local TDX installation.
    daily_bars = TdxUtil("D:\\new_tdx").get_tdx_stock_data("002448")
    print(daily_bars)