Refactor some content
This commit is contained in:
parent 786723438e
commit e2db3cf16e
@ -0,0 +1,22 @@
mysql = {
    "url": "127.0.0.1",
    "port": "3306",
    "username": "root",
    "password": "qeadzc123",
    "database": "qnloft_hospital",
}
# Basic data
stock_info_db = "stock_info.db"
# Sector-related data
stock_sector_db = "stock_sector.db"
# Per-stock daily bars
stock_daily_db = "stock_daily.db"
# Minute bars
stock_daily_freq_db = "stock_daily_freq.db"
# redis_info = ['10.10.XXX', XXX]
# mongo_info = ['10.10.XXX', XXX]
# es_info = ['10.10.XXX', XXX]
# sqlserver_info = ['10.10.XXX:XXX', 'XXX', 'XXX', 'XXX']
# db2_info = ['10.10.XXX', XXX, 'XXX', 'XXX']
# postgre_info = ['XXX', 'XXX', 'XXX', '10.10.XXX', XXX]
# ck_info = ['10.10.XXX', XXX]
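
# Hedged usage sketch (not part of this commit): the SQLite file names above are meant
# to be handed to SqliteDbMain (defined later in this commit), and the mysql dict feeds
# the engine URL built in mysql_db_main.py. Assumes the DB package is importable as below.
# from DB import db_config as config
# from DB.sqlite_db_main import SqliteDbMain
#
# info_db = SqliteDbMain(config.stock_info_db)   # opens/creates stock_info.db
# print(info_db.get_db_size())                   # e.g. "0.01 MB"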
@ -0,0 +1,271 @@
import json
import traceback

import pandas
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, scoped_session, aliased, declarative_base, clear_mappers
from sqlalchemy import *

from DB import db_config as config


class DbMain:

    def __init__(self):
        # clear_mappers()
        self.inspector = None
        self.session = None
        self.engine = None

    def get_session(self):
        # Bind the engine
        session_factory = sessionmaker(bind=self.engine)
        # Create a connection pool; using the session hands the current thread a connection
        # object (isolation is handled internally via threading.local)
        Session = scoped_session(session_factory)
        return Session()

    # ===================== insert / update methods =============================
    def insert_all_entry(self, entries):
        try:
            self.create_table(entries)
            self.session.add_all(entries)
            self.session.commit()
        except Exception as e:
            print(e)
        finally:
            self.session.close()

    def insert_entry(self, entry):
        try:
            self.create_table(entry)
            self.session.add(entry)
            self.session.commit()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}")
        finally:
            self.session.close()

    def insert_or_update(self, entry, query_conditions):
        """
        insert_or_update requires the model to implement a to_dict method that converts
        the fields to be updated into a dict.
        :param entry:
        :param query_conditions:
        :return:
        """
        try:
            self.create_table(entry)
            # Note: values are interpolated directly into the SQL string, so the
            # conditions must come from trusted input.
            conditions = " AND ".join(f"{key} = '{value}'" for key, value in query_conditions.items())
            select_sql = text(f"select count(1) from {entry.__tablename__} where 1=1 and {conditions}")
            result = self.session.execute(select_sql)
            # If the query returns rows, perform an update
            if result.scalar() > 0:
                if hasattr(entry, 'to_dict'):
                    formatted_attributes = ", ".join([f'{attr} = "{value}"' for attr, value in entry.to_dict().items()])
                    update_sql = text(f"UPDATE {entry.__tablename__} SET {formatted_attributes} WHERE {conditions}")
                    self.session.execute(update_sql)
                else:
                    raise Exception("The object has no to_dict method, please add one!")
            else:
                # Otherwise perform an insert
                self.insert_entry(entry)
            self.session.commit()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.expire_all()
            self.session.close()

    def create_table(self, entries):
        if isinstance(entries, list):
            table_name = entries[0].__tablename__
        else:
            table_name = entries.__tablename__
        # Check whether the table exists; create it if it does not
        if not self.inspector.has_table(table_name):
            if isinstance(entries, list):
                entries[0].metadata.create_all(self.engine)
            else:
                entries.metadata.create_all(self.engine)

    def pandas_insert(self, data: pandas.DataFrame = None, table_name=""):
        """
        Insert data held in a pandas DataFrame.
        :param data:
        :param table_name:
        :return:
        """
        try:
            data.to_sql(table_name, con=self.engine.connect(), if_exists='append', index=True, index_label='id')
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()
            self.engine.dispose()

    # ===================== select methods =============================
    def query_by_id(self, model, db_id):
        try:
            return self.session.query(model).filter_by(id=db_id).all()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
        """
        Run the query through pandas' SQL engine.
        :param page_size: records per page
        :param page_number: page N
        :param model:
        :param order_col:
        :return:
        """
        try:
            # Check whether the table exists
            if self.has_table(model):
                query = self.session.query(model)
                if order_col is not None:
                    query = query.order_by(order_col)
                if page_number is not None and page_size is not None:
                    offset = (page_number - 1) * page_size
                    query = query.offset(offset).limit(page_size)
                return pandas.read_sql_query(query.statement, self.engine.connect())
            return pandas.DataFrame()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_sql(self, stmt=""):
        """
        Run the statement through pandas' SQL engine.
        :param stmt:
        :return:
        """
        try:
            return pandas.read_sql_query(sql=stmt, con=self.engine.connect())
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_condition(self, model, query_condition):
        try:
            # When querying by several conditions at once:
            # query_condition = and_(
            #     StockDaily.trade_date == '20230823',
            #     StockDaily.symbol == 'ABC'
            # )
            query = self.session.query(model).filter(query_condition).order_by(model.id)
            return self.pandas_query_by_sql(stmt=query.statement).reset_index()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    # ===================== delete methods =============================
    def delete_by_id(self, model, db_id):
        try:
            # Use delete() to remove the matching records
            self.session.query(model).filter_by(id=db_id).delete()
            self.session.commit()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    def delete_by_condition(self, model, delete_condition):
        try:
            # Use delete() to remove the matching records.
            # Define the condition for the rows to delete, e.g. rows whose trade_date is '20230823':
            # delete_condition = StockDaily.trade_date == '20230823'
            # For multiple conditions:
            # delete_condition = and_(
            #     StockDaily.trade_date == '20230823',
            #     StockDaily.symbol == 'ABC'
            # )
            self.session.query(model).filter(delete_condition).delete()
            self.session.commit()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    def delete_all_table(self, model):  # Clear all rows in the table
        try:
            self.session.query(model).delete()
            self.session.commit()
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    # ===================== other methods =============================
    def has_table(self, entries):
        if isinstance(entries, list):
            table_name = entries[0].__tablename__
        else:
            table_name = entries.__tablename__
        # Check whether the table exists
        return self.inspector.has_table(table_name)

    def execute_sql(self, s):
        try:
            sql_text = text(s)
            return self.session.execute(sql_text)
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    def execute_sql_to_pandas(self, s):
        try:
            sql_text = text(s)
            res = self.session.execute(sql_text)
            return pandas.DataFrame(res.fetchall(), columns=res.keys())
        except Exception as e:
            trace = traceback.extract_tb(e.__traceback__)
            for filename, lineno, funcname, source in trace:
                print(f"Error in file {filename} at line {lineno}, function: {funcname}, source: {source}"
                      f" details: {traceback.format_exc()}")
        finally:
            self.session.close()

    def close(self):
        self.session.close()
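
# Hedged usage sketch (assumes SqliteDbMain and StockBasic defined elsewhere in this
# commit). insert_or_update only works for models that expose to_dict(), which is why
# StockBasic and the daily/freq models define it.
# from DB.sqlite_db_main import SqliteDbMain, config
# from DB.model.StockBasic import StockBasic
#
# db = SqliteDbMain(config.stock_info_db)
# entry = StockBasic(ts_code="600000.SH", symbol="600000", name="浦发银行",
#                    delist_date="None", is_shsc="1", comp_name="", comp_name_en="")
# # Inserts when no row matches ts_code, otherwise updates the to_dict() fields.
# db.insert_or_update(entry, query_conditions={"ts_code": "600000.SH"})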
@ -0,0 +1,20 @@
from sqlalchemy import Column, Integer, String, Index
from sqlalchemy.orm import declarative_base

'''
Concept-sector data table
'''
class ConceptSector(declarative_base()):
    __tablename__ = 'stock_concept_sector'

    id = Column(Integer, primary_key=True, autoincrement=True)
    trade_date = Column(String, nullable=False)
    sector_name = Column(String)
    sector_code = Column(String)
    pct_change = Column(String)
    total_market = Column(String)
    rising_count = Column(String)
    falling_count = Column(String)
    new_price = Column(String)
    leading_stock = Column(String)
    leading_stock_pct_change = Column(String)
@ -0,0 +1,27 @@
# Table mapping class
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

'''
Industry-sector data table
'''
class IndustrySector(declarative_base()):
    __tablename__ = 'stock_industry_sector'

    id = Column(Integer, primary_key=True, autoincrement=True)
    trade_date = Column(String, nullable=False)
    sector_name = Column(String)
    pct_change = Column(String)
    total_volume = Column(String)
    total_turnover = Column(String)
    net_inflows = Column(String)
    rising_count = Column(String)
    falling_count = Column(String)
    average_price = Column(String)
    leading_stock = Column(String)
    leading_stock_latest_price = Column(String)
    leading_stock_pct_change = Column(String)


# Create indexes
# trade_date_index = Index('idx_trade_date', IndustrySector.trade_date)
# sector_name_index = Index('idx_sector_name', IndustrySector.sector_name)
@ -0,0 +1,42 @@
# Table mapping class
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base


# StockBasic mapping class
class StockBasic(declarative_base()):
    __tablename__ = 'stock_basic'

    id = Column(Integer, primary_key=True, autoincrement=True)
    ts_code = Column(String, nullable=False, unique=True)
    symbol = Column(String, nullable=False)
    name = Column(String)
    comp_name = Column(String)
    comp_name_en = Column(String)
    isin_code = Column(String)
    exchange = Column(String)
    list_board = Column(String)
    list_date = Column(String)
    delist_date = Column(String)
    crncy_code = Column(String)
    pinyin = Column(String)
    list_board_name = Column(String)
    is_shsc = Column(String)
    comp_code = Column(String)

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        return f"<StockBasic(ts_code='{self.ts_code}', name='{self.name}', exchange='{self.exchange}')>"

    def to_dict(self):
        # Attributes to keep
        allowed_attributes = ['delist_date', 'is_shsc', 'comp_name', 'comp_name_en']

        # Build a dict containing only those attributes
        obj_dict = {}
        for attr in allowed_attributes:
            obj_dict[attr] = getattr(self, attr)

        return obj_dict
@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

'''
Stocks belonging to a concept sector
'''
class StockByConceptSector(declarative_base()):
    __tablename__ = 'stock_by_concept_sector'

    id = Column(Integer, primary_key=True, autoincrement=True)
    update_date = Column(String, nullable=False)
    sector_name = Column(String)
    sector_code = Column(String)
    stock_name = Column(String)
    stock_code = Column(String)
@ -0,0 +1,15 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

'''
Stocks belonging to an industry sector
'''
class StockByIndustrySector(declarative_base()):
    __tablename__ = 'stock_by_industry_sector'

    id = Column(Integer, primary_key=True, autoincrement=True)
    update_date = Column(String, nullable=False)
    sector_name = Column(String)
    sector_code = Column(String)
    stock_name = Column(String)
    stock_code = Column(String)
@ -0,0 +1,58 @@
# Table mapping class
from sqlalchemy import Column, Integer, String, Float
from sqlalchemy.orm import declarative_base


def get_stock_daily(table_name):
    class StockDaily(declarative_base()):
        __tablename__ = table_name

        id = Column(Integer, primary_key=True, autoincrement=True)
        ts_code = Column(String, nullable=False)  # TS code
        trade_date = Column(String, nullable=False)  # trade date
        crncy_code = Column(String)  # currency code
        pre_close = Column(Float)  # previous close (CNY)
        open = Column(Float)  # open (CNY)
        high = Column(Float)  # high (CNY)
        low = Column(Float)  # low (CNY)
        close = Column(Float)  # close (CNY)
        change = Column(Float)  # change (CNY)
        pct_chg = Column(Float)  # change (%)
        volume = Column(Float)  # volume (lots)
        amount = Column(Float)  # turnover (thousand CNY)
        adj_pre_close = Column(Float)  # adjusted previous close (CNY)
        adj_open = Column(Float)  # adjusted open (CNY)
        adj_high = Column(Float)  # adjusted high (CNY)
        adj_low = Column(Float)  # adjusted low (CNY)
        adj_close = Column(Float)  # adjusted close (CNY)
        adj_factor = Column(Float)  # adjustment factor
        avg_price = Column(Float)  # average price (VWAP)
        trade_status = Column(String)  # trading status
        turnover_rate = Column(Float)  # turnover rate
        turnover_rate_f = Column(Float)  # turnover rate (free float)
        volume_ratio = Column(Float)  # volume ratio
        pe = Column(Float)  # P/E (total market cap / net profit)
        pe_ttm = Column(Float)  # P/E (TTM)
        pb = Column(Float)  # P/B (total market cap / net assets)
        ps = Column(Float)  # P/S
        ps_ttm = Column(Float)  # P/S (TTM)
        dv_ratio = Column(Float)  # dividend yield
        dv_ttm = Column(Float)  # trailing dividend yield
        total_share = Column(Float)  # total shares
        float_share = Column(Float)  # floating shares
        free_share = Column(Float)  # free-floating shares
        total_mv = Column(Float)  # total market cap
        circ_mv = Column(Float)  # circulating market cap

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def __repr__(self):
            return f"<StockDaily(ts_code='{self.ts_code}', trade_date='{self.trade_date}', close='{self.close}')>"

        def to_dict(self):
            return {"ts_code": self.ts_code, "trade_date": self.trade_date, "pre_close": self.pre_close,
                    "turnover_rate": self.turnover_rate}

    return StockDaily
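
# Hedged usage sketch (assumed table name): the factory builds a new mapped class per
# table, so each ticker can get its own daily-bar table.
# daily_600000 = get_stock_daily("600000_daily")
# row = daily_600000(ts_code="600000.SH", trade_date="20230823", pre_close=7.1,
#                    close=7.2, turnover_rate=0.35)
# print(row.to_dict())  # {'ts_code': '600000.SH', 'trade_date': '20230823', ...}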
@ -0,0 +1,26 @@
from sqlalchemy import Column, Integer, String, Float
from sqlalchemy.orm import declarative_base


def get_stock_daily_freq(table_name):
    class StockDailyFreq(declarative_base()):
        __tablename__ = table_name

        id = Column(Integer, primary_key=True)
        trade_date = Column(String)  # trade date
        time = Column(String)  # time
        open = Column(Float)  # open
        close = Column(Float)  # close
        high = Column(Float)  # high
        low = Column(Float)  # low
        vol = Column(Float)  # volume (note the unit: lots)
        amount = Column(Float)  # turnover

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def to_dict(self):
            return {"trade_date": self.trade_date, "time": self.time,
                    "close": self.close}

    return StockDailyFreq
@ -0,0 +1,26 @@
from DB.db_main import *


def create_mysql_engine():
    # Create the engine
    return create_engine(
        f"mysql+pymysql://{config.mysql['username']}:{config.mysql['password']}@{config.mysql['url']}:{config.mysql['port']}/{config.mysql['database']}?charset=utf8mb4",
        # "mysql+pymysql://tom@127.0.0.1:3306/db1?charset=utf8mb4",  # without a password
        # maximum number of connections allowed beyond the pool size
        max_overflow=0,
        # connection pool size
        pool_size=5,
        # max seconds to wait for a free connection before raising an error
        pool_timeout=10,
        # how often (seconds) pooled connections are recycled
        pool_recycle=1,
        # log the raw (unformatted) SQL statements
        echo=True
    )


class MysqlDbMain(DbMain):
    def __init__(self):
        super().__init__()
        self.engine = create_mysql_engine()
        self.session = self.get_session()
        # Note: unlike SqliteDbMain, self.inspector is not initialised here, so
        # create_table()/has_table() cannot be used with this engine as written.
@ -0,0 +1,50 @@
from pathlib import Path

from DB.db_main import *
import platform
import logging


class SqliteDbMain(DbMain):
    def __init__(self, database_name):
        # logging.basicConfig()
        # logging.getLogger('sqlalchemy.engine').setLevel(logging.ERROR)
        # logging.getLogger('sqlalchemy.pool').setLevel(logging.ERROR)
        # logging.getLogger('sqlalchemy.dialects').setLevel(logging.ERROR)
        # logging.getLogger('sqlalchemy.orm').setLevel(logging.ERROR)
        self.database_name = database_name
        super().__init__()
        self.engine = self.__create_sqlite_engine()
        self.engine_path = self.__get_path()
        self.session = self.get_session()
        self.inspector = inspect(self.engine)

    def __get_path(self):
        sys_platform = platform.platform().lower()
        print(f'Current OS: {platform.platform()}')
        __engine = ''
        if 'windows' in sys_platform:
            __engine = f"E:\\sqlite_db\\stock_db\\{self.database_name}"
        elif 'macos' in sys_platform:
            __engine = f"/Users/renmeng/Documents/sqlite_db/{self.database_name}"
        else:
            __engine = f"{self.database_name}"
        return __engine

    def __create_sqlite_engine(self):
        sys_platform = platform.platform().lower()
        __engine = ''
        if 'windows' in sys_platform:
            __engine = f"sqlite:///E:\\sqlite_db\\stock_db\\{self.database_name}"
        elif 'macos' in sys_platform:
            __engine = f"sqlite:////Users/renmeng/Documents/sqlite_db/{self.database_name}"
        else:
            __engine = f"sqlite:///{self.database_name}"
        print(f"Current engine URL: {__engine}")
        return create_engine(__engine, pool_size=10, pool_timeout=10, echo=True)

    def get_db_size(self):
        file_size = Path(self.engine_path).stat().st_size
        total = f"{file_size / (1024 * 1024):.2f} MB"
        print(f"File size: {file_size / (1024 * 1024):.2f} MB")
        return total
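
# Hedged usage sketch (assumed model): pull one page of rows back as a DataFrame via
# the pandas helper inherited from DbMain.
# from DB.model.StockBasic import StockBasic
#
# db = SqliteDbMain(config.stock_info_db)
# df = db.pandas_query_by_model(StockBasic, order_col=StockBasic.id,
#                               page_number=1, page_size=50)
# print(df.shape)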
@ -0,0 +1,119 @@
from datetime import timedelta

from DB.model.StockDailyFreq import get_stock_daily_freq
from DB.sqlite_db_main import SqliteDbMain, config
from utils.tdxUtil import TdxUtil
from utils.comm import *
from 基本信息入库 import StockInfoMain


class StockDailyFreqMain:

    def __init__(self):
        self.code_res = StockInfoMain().get_stock_basic()
        self.db_main = SqliteDbMain(config.stock_daily_freq_db)

    def __filter_stock(self, code, name):
        if 'ST' in name:
            return False
        if code.startswith('30'):
            return False
        if code.startswith('68'):
            return False
        return True

    def init_data(self):
        """
        Full (initial) load.
        :return:
        """
        # Print the results
        for result in self.code_res:
            tdx_util = TdxUtil("")
            s_type = tdx_util.get_security_type(code=result.ts_code)
            if s_type in tdx_util.SECURITY_TYPE and self.__filter_stock(result.ts_code, result.name):
                # Set the start and end time
                start_time = f"{(datetime.now() - timedelta(days=365 * 3)).strftime('%Y-%m-%d')} 09:30:00"
                end_time = f"{datetime.now().strftime('%Y-%m-%d')} 15:00:00"
                print(
                    f"{result.id},{result.ts_code} --> start_time={start_time} --->> end_time={end_time}")
                i = 0
                df = None
                while True:
                    try:
                        if i > 6:
                            break
                        df = xcsc_pro.stk_mins(ts_code=result.ts_code, freq='30min',
                                               start_time=start_time,
                                               end_time=end_time)[::-1]
                        # Fetched successfully, stop retrying
                        break
                    except Exception as e:
                        i += 1
                        # Timeout caught: wait two minutes and retry
                        print("Request timed out, retrying in 2 minutes...")
                        time.sleep(120)  # sleep for 2 minutes
                print(f"{result.ts_code},{result.name} finished fetching stk_mins 30-minute data!")
                # Create the table
                table_name = str(result.ts_code).split(".")[0] + "_daily_freq_30"
                new_table_class = get_stock_daily_freq(table_name=table_name)
                self.db_main.create_table(new_table_class)
                entries = []
                for index, row in df.iterrows():
                    min_time = row['trade_time']
                    # Parse the string into a datetime
                    trade_time = datetime.strptime(min_time, '%Y-%m-%d %H:%M:%S').strftime('%Y-%m-%d')
                    entry = new_table_class(
                        trade_date=trade_time,
                        time=min_time,
                        open=row['open'],
                        close=row['close'],
                        high=row['high'],
                        low=row['low'],
                        vol=row['vol'],
                        amount=row['amount'],
                    )
                    print(entry.to_dict())
                    entries.append(entry)
                self.db_main.insert_all_entry(entries)

    def task_data(self, trade_date=datetime.now().strftime('%Y%m%d')):
        print(len(self.code_res))
        for result in self.code_res:
            tdx_util = TdxUtil("")
            s_type = tdx_util.get_security_type(code=result.ts_code)
            if s_type in tdx_util.SECURITY_TYPE and self.__filter_stock(result.ts_code, result.name):
                print(
                    f"{result.id},{result.ts_code} --> start_time={trade_date} 09:30:00 --->> end_time={trade_date} 15:00:00")

                df = xcsc_pro.stk_mins(ts_code=result.ts_code, freq='30min',
                                       start_time=f"{trade_date} 09:30:00",
                                       end_time=f"{trade_date} 15:00:00")
                print(f"{result.ts_code},{result.name} finished fetching stk_mins 30-minute data!")
                for index, row in df.iterrows():
                    table_name = str(result.ts_code).split(".")[0] + "_daily_freq_30"
                    new_table_class = get_stock_daily_freq(table_name=table_name)
                    entry = new_table_class(
                        trade_date=trade_date,
                        time=row['trade_time'],
                        open=row['open'],
                        close=row['close'],
                        high=row['high'],
                        low=row['low'],
                        vol=row['vol'],
                        amount=row['amount'],
                    )
                    print(entry.to_dict())
                    self.db_main.insert_entry(entry)


if __name__ == '__main__':
    current_date = datetime.now()
    if if_run(current_date):
        trade_date = current_date.strftime('%Y-%m-%d')
        print(trade_date)
        main = StockDailyFreqMain()
        # main.init_data()
        main.task_data(trade_date)

    # df = xcsc_pro.stk_mins(ts_code='600095.SH', start_time='2022-08-24 09:30:00', end_time='2023-08-24 15:00:00',
    #                        freq='30min')
    # print(df)
@ -0,0 +1,71 @@
from sqlalchemy import and_

from DB.model.StockBasic import StockBasic
from DB.sqlite_db_main import SqliteDbMain, config
from utils.comm import *


class StockInfoMain:

    def __init__(self):
        self.db_main = SqliteDbMain(config.stock_info_db)
        self.session = self.db_main.get_session()

    def insert_stock_basic(self):
        pro = xcsc_pro
        info = ['SZSE', 'SSE']
        # entries = []
        for item in info:
            # SSE: Shanghai Stock Exchange, SZSE: Shenzhen Stock Exchange
            df = pro.stock_basic(exchange=item)
            for index, row in df.iterrows():
                entry = StockBasic(**row)
                self.db_main.insert_or_update(entry, query_conditions={'ts_code': row['ts_code']})

    def get_stock_basic(self, ts_code=None, symbol=None, restart_id=0):
        sql_cond = and_(
            StockBasic.delist_date.is_('None')
        )
        try:
            if restart_id > 0:
                sql_cond = and_(
                    StockBasic.id >= restart_id,
                    StockBasic.delist_date.is_('None')
                )
            if ts_code is not None:
                if isinstance(ts_code, str):
                    sql_cond = and_(
                        StockBasic.ts_code.is_(ts_code),
                        StockBasic.delist_date.is_('None')
                    )
                elif isinstance(ts_code, list):
                    sql_cond = and_(
                        StockBasic.ts_code.in_(ts_code),
                        StockBasic.delist_date.is_('None')
                    )
            elif symbol is not None:
                if isinstance(symbol, str):
                    sql_cond = and_(
                        StockBasic.symbol.is_(symbol),
                        StockBasic.delist_date.is_('None')
                    )
                elif isinstance(symbol, list):
                    sql_cond = and_(
                        StockBasic.symbol.in_(symbol),
                        StockBasic.delist_date.is_('None')
                    )
            results = self.session.query(StockBasic).filter(sql_cond).all()
            return results
        finally:
            self.session.close()


    # try:
    #     results = self.session.query(StockBasic).filter(StockBasic.delist_date.is_(None)).all()
    #     return results
    # finally:
    #     self.session.close()


if __name__ == '__main__':
    print(StockInfoMain().get_stock_basic(restart_id=5))
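
# Hedged usage sketch: restart_id lets a long-running job resume from a given primary
# key, and ts_code/symbol accept either a single code or a list of codes.
# stock_info = StockInfoMain()
# rows = stock_info.get_stock_basic(ts_code=["600000.SH", "000001.SZ"])
# for r in rows:
#     print(r.ts_code, r.name)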
@ -0,0 +1,124 @@
from utils.comm import *
from snownlp import SnowNLP
import jieba
import jieba.posseg as pseg
import jieba.analyse


class News:
    def __init__(self, trade_date=datetime.now()):
        self.trade_date = trade_date.strftime('%Y%m%d')

    @staticmethod
    def preprocess(text):
        seg_list = jieba.lcut(text)  # preprocessing / word segmentation
        return seg_list

    @staticmethod
    def sentiment_analysis(text):  # sentiment analysis
        s = SnowNLP(text)
        sentiment = s.sentiments  # score in [0, 1]; closer to 1 is positive, closer to 0 is negative
        return sentiment

    @staticmethod
    def extract_keywords(text, top_k=5):  # keyword extraction
        keywords = jieba.analyse.extract_tags(text, topK=top_k, withWeight=False, allowPOS=())  # top 5 keywords by default
        return keywords

    @staticmethod
    def named_entity_recognition(text):  # named entity recognition
        words = pseg.cut(text)
        entities = []
        for word, flag in words:
            # person, place, organisation and other proper-noun tags
            if flag in ['nr', 'ns', 'nt', 'nz']:
                entities.append((word, flag))
        return entities

    def result(self, news):
        print("Original text:", news)
        preprocessed_news = self.preprocess(news)
        sentiment_score = self.sentiment_analysis(news)
        keywords = self.extract_keywords(news, top_k=3)
        entities = self.named_entity_recognition(news)

        print("Segmentation:", preprocessed_news)
        print("Sentiment score:", round(sentiment_score, 2))
        print("Keywords:", keywords)
        print("Named entities:", entities)

    def news_cctv(self):
        for _ in range(5):
            try:
                # CCTV news
                df = ak.news_cctv(date=self.trade_date)
                # Sentiment analysis
                # df['title'].apply(lambda x: self.result(x))
                # print(df['情感'])
                return df
            except Exception as e:
                print(f"Fetching CCTV news for {self.trade_date} failed! {e.__traceback__}")
                time.sleep(1)
                continue
        else:
            print("CCTV news failed 5 times, please check!!!")
            return None

    def news_cls(self):
        for _ in range(5):
            try:
                # CLS (Cailian Press) telegraph
                df = ak.stock_telegraph_cls()
                # Sentiment analysis
                # df['title'].apply(lambda x: self.result(x))
                # print(df['情感'])
                return df
            except Exception as e:
                print(f"{self.trade_date} fetching CLS telegraph failed! {e.__traceback__}")
                continue
        else:
            print("CLS telegraph failed 5 times, please check!!!")
            return None

    @staticmethod
    def news_stock_by_code(symbol, date_time=None):
        # News for an individual stock
        stock_news_em_df = ak.stock_news_em(symbol=symbol)
        # Convert the publish-time column to datetime
        stock_news_em_df['发布时间'] = pd.to_datetime(stock_news_em_df['发布时间'])
        # print(stock_news_em_df)
        if date_time is not None:
            if len(date_time.split("-")) == 1:
                str_date = datetime.strptime(str(date_time), '%Y%m%d').date()
            else:
                str_date = datetime.strptime(str(date_time), '%Y-%m-%d').date()
        else:
            # Today's date
            str_date = datetime.now().date()
        # Keep only the news published on that date
        today_news = stock_news_em_df[stock_news_em_df['发布时间'].dt.date == str_date]
        # Print that day's news
        print(today_news['新闻标题'])
        print(today_news["新闻内容"])
        # today_news['新闻标题'].apply(lambda x: result(x))
        return today_news

    def html_page_data(self):
        cctv_df = self.news_cctv()
        cls_df = self.news_cls()
        cctv_content, cls_content = '', ''
        for index, row in cctv_df.iterrows():
            cctv_content += f'<tr><th scope="row">{row["title"]}</th>' \
                            f'<td>开发中...</td>'
            cctv_content += '</tr>'
        for index, row in cls_df.iterrows():
            if len(row["标题"]) > 1:
                cls_content += f'<tr><th scope="row">{row["标题"]}</th>' \
                               f'<td>开发中...</td>'
                cls_content += '</tr>'
        return cctv_content, cls_content


if __name__ == '__main__':
    date_obj = datetime.strptime('20231019', "%Y%m%d")
    print(News().news_cctv())
@ -0,0 +1,285 @@
import ftplib
import os
from datetime import datetime
from ftplib import FTP, error_perm
from pathlib import Path

import akshare as ak
import efinance as ef
import qstock as qs
import tushare as ts
import pandas as pd
import time
import random
from tabulate import tabulate
import platform
import xcsc_tushare as xc

token = '0718534658b9d91b3f03dc8b220e4062193ebf6f6414d036505165e1'

xc.set_token('bd4f26f1eca8d660bd23e229260df46002d630d6e1fb9226380edec8')
# This connection style works for the simulation environment; production cannot be reached this way
xcsc_pro = xc.pro_api(env='prd', server='http://116.128.206.39:7172')

# Show all columns
pd.set_option('display.max_columns', None)
# Show all rows
pd.set_option('display.max_rows', None)
# Do not wrap wide frames
pd.set_option('expand_frame_repr', False)
# Maximum column width
pd.set_option('display.max_colwidth', None)


def sleep():
    print("Sleeping to avoid getting the IP banned or throttled")
    # Generate a random sleep time
    sleep_time = random.uniform(0, 2)
    # Sleep
    time.sleep(sleep_time)
    print("Done sleeping!")


def print_markdown(df: pd.DataFrame):
    print(tabulate(df, headers='keys', tablefmt='github'))


def get_file(size):
    """
    Return the data-file path for the current OS.
    :param size:
    :return:
    """
    sys_platform = platform.platform().lower()
    if "macos" in sys_platform:
        return f'/Users/renmeng/work_space/python_work/qnloft-get-web-everything/股票金融/量化交易/股票数据/{size}'
    elif "windows" in sys_platform:
        return f'E:\\Python_Workplace\\qnloft-get-web-everything\\股票金融\\量化交易\\股票数据\\{size}'
    else:
        print("Other OS")


def write_file(content, file_name, mode_type='a'):
    file = ""
    sys_platform = platform.platform().lower()
    if "macos" in sys_platform:
        file = f'/Users/renmeng/work_space/python_work/qnloft-get-web-everything/股票金融/量化交易/股票数据/{file_name}'
    elif "windows" in sys_platform:
        file = f'D:\\文档\\数据测试\\{file_name}'
    else:
        print("Other OS")
    path = Path(file)
    # Create the directory or file
    if not path.exists():
        if path.is_dir():
            path.mkdir(parents=True)  # create the directory and its parents
        else:
            path.touch()  # create the file
    with path.open(mode=mode_type) as file:
        file.write(content)


def create_file(file_name):
    path = Path(file_name)
    # Create the directory or file
    if not path.exists():
        path.mkdir(parents=True)  # create the directory and its parents


def get_random_stock(size, n=0):
    if size == 'M':
        s_list = Path(get_file(size)).read_text().splitlines()
    elif size == 'S':
        s_list = Path(get_file(size)).read_text().splitlines()
    elif size == 'L':
        s_list = Path(get_file(size)).read_text().splitlines()
    else:
        raise ValueError("Invalid size!")
    if n > 0:
        return random.sample(s_list, n)
    else:
        return s_list


def del_file_lines(size, to_delete):
    content = Path(get_file(size)).read_text()
    # Remove the matching entries
    for char in to_delete:
        content = content.replace(char, '')

    # Write the updated content back to the file
    with Path(get_file(size)).open(mode='w') as file:
        file.writelines(content)


def if_run(current_date=datetime.now()):
    """
    Check whether the given date is a trading day.
    :param current_date:
    :return:
    """
    df = ak.tool_trade_date_hist_sina()
    # Convert the date column to datetime
    df['trade_date'] = pd.to_datetime(df['trade_date'])
    date_now_hour = current_date.hour
    # Format the date as "yyyy-mm-dd"
    formatted_date = current_date.strftime('%Y-%m-%d')
    # Check whether the DataFrame contains that date
    contains_target_date = (df['trade_date'] == formatted_date).any()
    print(f'Today is {formatted_date}; market open today: {contains_target_date}')
    if contains_target_date:
        return True
    return False


def return_trading_day(now_date, offset_date, format_type='%Y%m%d'):
    """
    Return the trading day n days before/after the given date.
    :param format_type:
    :param offset_date:
    :param now_date:
    :return:
    """
    df = ak.tool_trade_date_hist_sina()
    # Convert the date column to datetime
    df['trade_date'] = pd.to_datetime(df['trade_date'])
    if isinstance(now_date, datetime):
        formatted_date = now_date.strftime('%Y-%m-%d')
    else:
        formatted_date = now_date
    selected_indexes = df[df['trade_date'] == formatted_date].index

    return df.loc[selected_indexes + offset_date, 'trade_date'].dt.strftime(format_type).values[0]


def upload_files_to_ftp(local_directory, remote_directory, file_name=None):
    """
    Upload to FTP.
    :param local_directory:
    :param remote_directory:
    :param file_name:
    :return:
    """
    try:
        ftp = FTP("qxu1142200198.my3w.com")
        ftp.login('qxu1142200198', '48XZ55MB')
    except error_perm as e:
        print(f"Network error: {e}")
        return False
    try:
        # Change directory
        ftp.cwd(remote_directory)
    except ftplib.error_perm:
        print("Directory does not exist, creating it...")
        ftp.mkd(remote_directory)
    try:
        path = Path(local_directory)
        # Directory or single file?
        if path.is_dir() and file_name is None:
            files = os.listdir(local_directory)  # get list of files in local directory
            for file in files:
                if file.endswith(".html"):  # upload only .html files
                    local_file = f"{local_directory}/{file}"
                    remote_file = f"{remote_directory}/{file}"
                    print(local_file, "----->>", remote_file)
                    with open(local_file, 'rb') as f:
                        try:
                            ftp.delete(remote_file)
                        except ftplib.error_perm:
                            print("Remote file does not exist or was already deleted...")
                        ftp.storbinary(f"STOR {remote_file}", f)  # upload file to FTP server
        if file_name is not None and len(file_name) > 0:
            local_file = f"{local_directory}/{file_name}"
            remote_file = f"{remote_directory}/{file_name}"
            print(local_file, "----->>", remote_file)
            with open(local_file, 'rb') as f:
                try:
                    ftp.delete(remote_file)
                except ftplib.error_perm:
                    print("Remote file does not exist or was already deleted...")
                ftp.storbinary(f"STOR {remote_file}", f)  # upload file to FTP server
    finally:
        ftp.quit()
    return True


def back_testing(df, cond: str, n=3):
    """
    Backtest a condition and measure the returns/drawdown that follow it.
    :param df:
    :param cond:
    :param n:
    :return:
    """
    filtered_df = df.eval(cond)
    total, tomorrow_rise, tomorrow_fall = 0, 0, 0
    res_df = pd.DataFrame(
        columns=['code', '日期', '明天涨跌幅', 'n日最大涨幅', 'n日平均涨跌幅', 'n日最大回撤'])
    for index, row in df.iterrows():
        if filtered_df[index] and index > n:
            # print(f'{row["trade_date"]} --> {row["KDJ_D"]} --> {row["KDJ_J"]}')
            total += 1
            # Next day's percentage change
            tomorrow_pre = df.iloc[-index:-index + 1]['pct_chg'].values[0].round(2)
            # Maximum gain over the next n days
            max_h = df.iloc[-index:-index + n]['pct_chg'].max().round(2)
            # Average change over the next n days
            mean_h = df.iloc[-index:-index + n]['pct_chg'].mean().round(2)
            # Maximum drawdown over the next n days
            h_max = df.iloc[-index:-index + n]['high'].max()
            l_min = df.iloc[-index:-index + n]['low'].min()
            max_ret = f"{(h_max - l_min) / l_min:.2f}"
            res_df.loc[total, 'code'], res_df.loc[total, '日期'] = row['ts_code'], row["trade_date"]
            res_df.loc[total, '明天涨跌幅'] = tomorrow_pre
            res_df.loc[total, 'n日最大涨幅'] = max_h
            res_df.loc[total, 'n日平均涨跌幅'] = mean_h
            res_df.loc[total, 'n日最大回撤'] = max_ret
    return res_df


def buying_and_selling_decisions(df, cond_buy: str, cond_sell: str, cond_buy_read=None, cond_sell_read=None):
    """
    Buy/sell backtest.
    :param df:
    :param cond_buy:
    :param cond_sell:
    :param cond_buy_read:
    :param cond_sell_read:
    :return:
    """
    total, buy, sell, total_chg, buy_date, flag_ready_b, flag_ready_s = 0, 0, 0, 0, 0, False, False
    res_df = pd.DataFrame(
        columns=['code', '买入日期', '卖出日期', '盈利', '持仓天数'])
    for index, row in df.iterrows():
        trade_date, c = row["trade_date"], row['close']
        if (eval(cond_buy_read) or cond_buy_read is None) and buy == 0:
            # Get ready to buy
            print(f"{trade_date} {cond_buy_read}, ready to buy!")
            flag_ready_b = True
        if eval(cond_buy) and buy == 0 and flag_ready_b:
            # Buy
            print(f"{trade_date} {cond_buy}, buying!")
            buy, buy_date = c, trade_date
            flag_ready_b = False
        if (eval(cond_sell_read) or cond_sell_read is None) and buy > 0:
            # Get ready to sell
            print(f"{trade_date} {cond_sell_read}, ready to sell!")
            flag_ready_s = True
        if eval(cond_sell) and buy > 0 and flag_ready_s:
            print(f"{trade_date} {cond_sell}, selling!")
            total += 1
            # Sell
            sell = c
            # Profit
            chg = round(((sell - buy) / buy) * 100, 2)
            # Holding period
            start_date = datetime.strptime(buy_date, "%Y%m%d")
            end_date = datetime.strptime(trade_date, "%Y%m%d")
            # Number of days between buy and sell
            interval_days = (end_date - start_date).days
            buy, flag_ready_s = 0, False
            res_df.loc[total, 'code'], res_df.loc[total, '买入日期'] = row['ts_code'], buy_date
            res_df.loc[total, '卖出日期'] = trade_date
            res_df.loc[total, '盈利'] = chg
            res_df.loc[total, '持仓天数'] = interval_days
    return res_df
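
# Hedged usage sketch with a tiny synthetic frame (assumed column set), just to show
# that cond is evaluated with DataFrame.eval and must reference columns of df.
# import numpy as np
#
# demo = pd.DataFrame({
#     'ts_code': ['600000.SH'] * 10,
#     'trade_date': [f'2023082{i}' for i in range(10)],
#     'pct_chg': np.random.uniform(-3, 3, 10),
#     'high': np.random.uniform(10, 11, 10),
#     'low': np.random.uniform(9, 10, 10),
#     'KDJ_D': np.random.uniform(0, 100, 10),
#     'KDJ_J': np.random.uniform(-20, 120, 10),
# })
# report = back_testing(demo, cond="KDJ_J < 0 and KDJ_D < 20", n=3)
# print_markdown(report)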
@ -0,0 +1,691 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd


def EMA(number, n):
    return pd.Series(number).ewm(alpha=2 / (n + 1), adjust=True).mean()


def MA(number, n):
    return pd.Series.rolling(number, n).mean()


def SMA(number, n, m=1):
    _df = number.fillna(0)
    return pd.Series(_df).ewm(com=n - m, adjust=True).mean()


def RM_SMA(DF, N, M):
    DF = DF.fillna(0)
    z = len(DF)
    var = np.zeros(z)
    var[0] = DF[0]
    for i in range(1, z):
        var[i] = (DF[i] * M + var[i - 1] * (N - M)) / N
    for i in range(z):
        DF[i] = var[i]
    return DF


def ATR(close, high, low, n):
    """
    True range / average true range.
    :param close:
    :param high:
    :param low:
    :param n:
    :return:
    """
    c, h, l_ = close, high, low
    mtr = MAX(MAX((h - l_), ABS(REF(c, 1) - h)), ABS(REF(c, 1) - l_))
    atr = MA(mtr, n)
    return pd.DataFrame({'MTR': mtr, 'ATR': atr})


def HHV(number, n):
    return pd.Series.rolling(number, n).max()


def LLV(number, n):
    return pd.Series.rolling(number, n).min()


def SUM(number, n):
    return pd.Series.rolling(number, n).sum()


def ABS(number):
    return np.abs(number)


def MAX(A, B):
    return np.maximum(A, B)


def MIN(A, B):
    var = IF(A < B, A, B)
    return var


def IF(COND, V1, V2):
    var = np.flip(np.where(COND, V1, V2))
    return pd.Series(var)[::-1]


def REF(DF, N):
    var = DF.diff(N)
    var = DF - var
    return var


def STD(number, n):
    return pd.Series.rolling(number, n).std()


def MACD(close, f, s, m):
    """
    MACD (moving average convergence/divergence).
    :param close:
    :param f:
    :param s:
    :param m:
    :return:
    """
    EMAFAST = EMA(close, f)
    EMASLOW = EMA(close, s)
    DIFF = EMAFAST - EMASLOW
    DEA = EMA(DIFF, m)
    MACD = (DIFF - DEA) * 2
    return pd.DataFrame({
        'DIFF': round(DIFF, 2),
        'DEA': round(DEA, 2), 'MACD': round(MACD, 2)})


def KDJ(close, high, low, n, m1, m2):
    """
    KDJ stochastic oscillator.
    :param close:
    :param high:
    :param low:
    :param n:
    :param m1:
    :param m2:
    :return:
    """
    c, h, l = close, high, low
    RSV = (c - LLV(l, n)) / (HHV(h, n) - LLV(l, n)) * 100
    K = SMA(RSV, m1, 1)
    D = SMA(K, m2, 1)
    J = 3 * K - 2 * D
    return pd.DataFrame({'KDJ_K': round(K, 2), 'KDJ_D': round(D, 2), 'KDJ_J': round(J, 2)})

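# Hedged usage sketch: these helpers take oldest-first price Series (newest-first data
# should be reversed first, as GOLD_MACD below does with df[::-1]). Parameter choices
# here are the common defaults (12,26,9 for MACD; 9,3,3 for KDJ), not values mandated
# by this module.
if __name__ == '__main__':
    demo_close = pd.Series([10.0, 10.2, 10.1, 10.4, 10.6, 10.5, 10.8, 11.0, 10.9, 11.2])
    demo_high = demo_close + 0.1
    demo_low = demo_close - 0.1
    print(MACD(demo_close, 12, 26, 9).tail())
    print(KDJ(demo_close, demo_high, demo_low, 9, 3, 3).tail())
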
def OSC(close, n, m):
    """
    Oscillator (rate-of-change line).
    :param close:
    :param n:
    :param m:
    :return:
    """
    c = close
    OS = (c - MA(c, n)) * 100
    MAOSC = EMA(OS, m)
    return pd.DataFrame({'OSC': OS, 'MAOSC': MAOSC})


def BBI(close, N1, N2, N3, N4):
    """
    Bull and bear index.
    :param close:
    :param N1:
    :param N2:
    :param N3:
    :param N4:
    :return:
    """
    bbi = (MA(close, N1) + MA(close, N2) + MA(close, N3) + MA(close, N4)) / 4
    return pd.DataFrame({'BBI': round(bbi, 2)})


def BBIBOLL(close, n, m, n1=3, n2=6, n3=12, n4=24):
    """
    BBI Bollinger bands.
    :param close:
    :param n1:
    :param n2:
    :param n3:
    :param n4:
    :param n:
    :param m:
    :return:
    """
    bbi_boll = BBI(close, n1, n2, n3, n4)['BBI']
    UPER = bbi_boll + m * STD(bbi_boll, n)
    DOWN = bbi_boll - m * STD(bbi_boll, n)
    return pd.DataFrame({'BBIBOLL': round(bbi_boll, 2), 'UPER': round(UPER, 2), 'DOWN': round(DOWN, 2)})


def PBX(close, n1, n2, n3, n4, n5, n6):
    """
    Waterfall lines.
    :param close:
    :param n1:
    :param n2:
    :param n3:
    :param n4:
    :param n5:
    :param n6:
    :return:
    """
    c = close
    PBX1 = (EMA(c, n1) + MA(c, 2 * n1) + MA(c, 4 * n1)) / 3
    PBX2 = (EMA(c, n2) + MA(c, 2 * n2) + MA(c, 4 * n2)) / 3
    PBX3 = (EMA(c, n3) + MA(c, 2 * n3) + MA(c, 4 * n3)) / 3
    PBX4 = (EMA(c, n4) + MA(c, 2 * n4) + MA(c, 4 * n4)) / 3
    PBX5 = (EMA(c, n5) + MA(c, 2 * n5) + MA(c, 4 * n5)) / 3
    PBX6 = (EMA(c, n6) + MA(c, 2 * n6) + MA(c, 4 * n6)) / 3
    return pd.DataFrame(
        {'PBX1': round(PBX1, 2), 'PBX2': round(PBX2, 2), 'PBX3': round(PBX3, 2),
         'PBX4': round(PBX4, 2), 'PBX5': round(PBX5, 2), 'PBX6': round(PBX6, 2)}
    )


def BOLL(close, N):  # Bollinger bands
    boll = MA(close, N)
    UB = boll + 2 * STD(close, N)
    LB = boll - 2 * STD(close, N)
    return pd.DataFrame({'BOLL': round(boll, 2), 'UB': round(UB, 2), 'LB': round(LB, 2)})


def ROC(close, n, m):
    """
    Rate of change.
    :param close:
    :param n:
    :param m:
    :return:
    """
    c = close
    roc = 100 * (c - REF(c, n)) / REF(c, n)
    maroc = MA(roc, m)
    return pd.DataFrame({'ROC': round(roc, 2), 'MAROC': round(maroc, 2)})


def MTM(close, n, m):
    """
    Momentum line.
    :param close:
    :param n:
    :param m:
    :return:
    """
    c = close
    mtm = c - REF(c, n)
    mtm_ma = MA(mtm, m)
    return pd.DataFrame({'MTM': round(mtm, 2), 'MTMMA': round(mtm_ma, 2)})


def MFI(close, high, low, vol, n):
    """
    Money flow index.
    :param close:
    :param high:
    :param low:
    :param vol:
    :param n:
    :return:
    """
    c, h, l, v = close, high, low, vol
    TYP = (c + h + l) / 3
    V1 = SUM(IF(TYP > REF(TYP, 1), TYP * v, 0), n) / \
         SUM(IF(TYP < REF(TYP, 1), TYP * v, 0), n)
    mfi = 100 - (100 / (1 + V1))
    return pd.DataFrame({'MFI': round(mfi, 2)})


def SKDJ(close, high, low, N, M):
    c = close
    LOWV = LLV(low, N)
    HIGHV = HHV(high, N)
    RSV = EMA((c - LOWV) / (HIGHV - LOWV) * 100, M)
    K = EMA(RSV, M)
    D = MA(K, M)
    return pd.DataFrame({'SKDJ_K': round(K, 2), 'SKDJ_D': round(D, 2)})


def WR(close, high, low, N, N1):
    """
    Williams %R.
    :param close:
    :param high:
    :param low:
    :param N:
    :param N1:
    :return:
    """
    c, h, l = close, high, low
    WR1 = round(100 * (HHV(h, N) - c) / (HHV(h, N) - LLV(l, N)), 2)
    WR2 = round(100 * (HHV(h, N1) - c) / (HHV(h, N1) - LLV(l, N1)), 2)
    return pd.DataFrame({'WR1': round(WR1, 2), 'WR2': round(WR2, 2)})


def BIAS(DF, N1, N2, N3):  # bias ratio
    CLOSE = DF
    BIAS1 = (CLOSE - MA(CLOSE, N1)) / MA(CLOSE, N1) * 100
    BIAS2 = (CLOSE - MA(CLOSE, N2)) / MA(CLOSE, N2) * 100
    BIAS3 = (CLOSE - MA(CLOSE, N3)) / MA(CLOSE, N3) * 100
    DICT = {'BIAS1': BIAS1, 'BIAS2': BIAS2, 'BIAS3': BIAS3}
    VAR = pd.DataFrame(DICT)
    return VAR


def RSI(c, N1, N2, N3):  # relative strength index: RSI1:SMA(MAX(CLOSE-LC,0),N1,1)/SMA(ABS(CLOSE-LC),N1,1)*100;
    DIF = c - REF(c, 1)
    RSI1 = round((SMA(MAX(DIF, 0), N1) / round(SMA(ABS(DIF), N1) * 100, 3)) * 10000, 2)
    RSI2 = round((SMA(MAX(DIF, 0), N2) / round(SMA(ABS(DIF), N2) * 100, 3)) * 10000, 2)
    RSI3 = round((SMA(MAX(DIF, 0), N3) / round(SMA(ABS(DIF), N3) * 100, 3)) * 10000, 2)
    return pd.DataFrame({'RSI1': RSI1, 'RSI2': RSI2, 'RSI3': RSI3})


def ADTM(DF, N, M):  # dynamic buy/sell momentum
    HIGH = DF['high']
    LOW = DF['low']
    OPEN = DF['open']
    DTM = IF(OPEN <= REF(OPEN, 1), 0, MAX(
        (HIGH - OPEN), (OPEN - REF(OPEN, 1))))
    DBM = IF(OPEN >= REF(OPEN, 1), 0, MAX((OPEN - LOW), (OPEN - REF(OPEN, 1))))
    STM = SUM(DTM, N)
    SBM = SUM(DBM, N)
    ADTM1 = IF(STM > SBM, (STM - SBM) / STM,
               IF(STM == SBM, 0, (STM - SBM) / SBM))
    MAADTM = MA(ADTM1, M)
    DICT = {'ADTM': ADTM1, 'MAADTM': MAADTM}
    VAR = pd.DataFrame(DICT)
    return VAR


def DDI(DF, N, N1, M, M1):  # directional deviation index
    H = DF['high']
    L = DF['low']
    DMZ = IF((H + L) <= (REF(H, 1) + REF(L, 1)), 0,
             MAX(ABS(H - REF(H, 1)), ABS(L - REF(L, 1))))
    DMF = IF((H + L) >= (REF(H, 1) + REF(L, 1)), 0,
             MAX(ABS(H - REF(H, 1)), ABS(L - REF(L, 1))))
    DIZ = SUM(DMZ, N) / (SUM(DMZ, N) + SUM(DMF, N))
    DIF = SUM(DMF, N) / (SUM(DMF, N) + SUM(DMZ, N))
    ddi = DIZ - DIF
    ADDI = SMA(ddi, N1, M)
    AD = MA(ADDI, M1)
    DICT = {'DDI': ddi, 'ADDI': ADDI, 'AD': AD}
    VAR = pd.DataFrame(DICT)
    return VAR


ZIG_STATE_START = 0
ZIG_STATE_RISE = 1
ZIG_STATE_FALL = 2


def ZIG(d, k, n):
    """
    Zig-zag indicator: records a turn once the price has moved by more than x%.
    :param d: trade dates
    :param k: prices
    :param n: threshold coefficient (percent)
    :return:
    """
    x = round(n / 100, 2)
    peer_i = 0
    candidate_i = None
    scan_i = 0
    peers = [0]
    z = np.zeros(len(k))
    state = ZIG_STATE_START
    while True:
        scan_i += 1
        if scan_i == len(k) - 1:
            # Reached the end of the series
            if candidate_i is None:
                peer_i = scan_i
                peers.append(peer_i)
            else:
                if state == ZIG_STATE_RISE:
                    if k[scan_i] >= k[candidate_i]:
                        print(d[scan_i], "1 --->>>", d[candidate_i])
                        peer_i = scan_i
                        peers.append(peer_i)
                    else:
                        peer_i = candidate_i
                        peers.append(peer_i)
                        peer_i = scan_i
                        peers.append(peer_i)
                elif state == ZIG_STATE_FALL:
                    if k[scan_i] <= k[candidate_i]:
                        print(d[scan_i], "2 --->>>", d[candidate_i])
                        peer_i = scan_i
                        peers.append(peer_i)
                    else:
                        peer_i = candidate_i
                        peers.append(peer_i)
                        peer_i = scan_i
                        peers.append(peer_i)
            break
        if state == ZIG_STATE_START:
            if k[scan_i] >= k[peer_i] * (1 + x):
                print(d[scan_i], "3 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_RISE
            elif k[scan_i] <= k[peer_i] * (1 - x):
                print(d[scan_i], "4 --->>>", d[peer_i])
                candidate_i = scan_i
                state = ZIG_STATE_FALL
        elif state == ZIG_STATE_RISE:
            if k[scan_i] >= k[candidate_i]:
                candidate_i = scan_i
            elif k[scan_i] <= k[candidate_i] * (1 - x):
                print(d[scan_i], "5 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_FALL
                candidate_i = scan_i
        elif state == ZIG_STATE_FALL:
            if k[scan_i] <= k[candidate_i]:
                print(d[scan_i], "6 --->>>", d[candidate_i])
                candidate_i = scan_i
            elif k[scan_i] >= k[candidate_i] * (1 + x):
                print(d[scan_i], "7 --->>>", d[candidate_i])
                peer_i = candidate_i
                peers.append(peer_i)
                state = ZIG_STATE_RISE
                candidate_i = scan_i
    for i in range(len(peers) - 1):
        peer_start_i = peers[i]
        peer_end_i = peers[i + 1]
        start_value = k[peer_start_i]
        end_value = k[peer_end_i]
        a = (end_value - start_value) / (peer_end_i - peer_start_i)  # slope
        for j in range(peer_end_i - peer_start_i + 1):
            z[j + peer_start_i] = start_value + a * j
    return pd.Series(z), peers


def TROUGHBARS(z, p, m):
    """
    Distance from the m-th previous zig trough to the current bar.
    :param z: zig indicator
    :param p: zig turning points
    :param m: coefficient
    :return:
    """
    trough_bars = np.zeros(len(z))
    if len(z) > 3:
        j = 1
        # Decide whether the first turning point is a trough or a peak: peaks use even
        # indexes, troughs use odd indexes
        if z[0] > z[1]:
            # First point is a trough
            for i in range(len(p)):
                peer = p[i]
                j = i + m * 2
                if 0 < i and len(p) > j and i % 2 == 1:
                    num = p[j] - peer - 1
                    trough_bars[p[j] - 1] = num
                    trough_bars[p[j]] = 1
        if z[0] < z[1]:
            # First point is a peak
            for i in range(len(p)):
                peer = p[i]
                j = i + m * 2
                if 0 < i and len(p) > j and i % 2 == 0:
                    num = p[j] - peer - 1
                    trough_bars[p[j] - 1] = num
                    trough_bars[p[j]] = 1
    return pd.Series(trough_bars)


def CROSS(a, b):
    """
    Crossover signal.
    Marks 1 where a crosses above b, -1 where a crosses below b, 0 otherwise.
    :param a:
    :param b:
    :return:
    """
    assert len(a) == len(b), 'crossover inputs must have the same length'
    assert len(a) > 1, 'crossover inputs need at least 2 elements'
    res = np.zeros(len(a))
    for i in range(len(a) - 2, -1, -1):
        if a[i + 1] <= b[i + 1] and a[i] > b[i] and a[i + 1] < a[i]:
            # Upward crossover: mark 1
            res[i] = 1
        elif a[i + 1] >= b[i + 1] and a[i] < b[i] and a[i + 1] > a[i]:
            res[i] = -1
        else:
            res[i] = 0
        # print(f"a+1 = {a[i + 1]}, b+1 = {b[i + 1]} , a = {a[i]} , b = {b[i]} , res = {res[i]}")
    return pd.Series(res)

|
||||
return np.polyfit(range(len(x)), x, 1)[0]
|
||||
|
||||
|
||||
def rolling_window(a, window):
|
||||
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
|
||||
strides = a.values.strides + (a.values.strides[-1],)
|
||||
return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
|
||||
|
||||
|
||||
def SLOPE(series, n):
|
||||
"""
|
||||
SLOPE(X,N) 返回线性回归斜率,N支持变量
|
||||
参考:https://blog.csdn.net/luhouxiang/article/details/113816062
|
||||
"""
|
||||
a = rolling_window(series, n)
|
||||
obj = np.array([_calc_slope(x) for x in a])
|
||||
new_obj = np.pad(obj, (len(series) - len(obj), 0), 'constant', constant_values=(np.nan, np.nan))
|
||||
return new_obj
|
||||
|
||||
|
||||
def GOLD_MACD(df_data: pd.DataFrame):
|
||||
"""
|
||||
黄金MACD指标
|
||||
:param df_data:
|
||||
:return:
|
||||
"""
|
||||
df = df_data[::-1].reset_index(drop=True)
|
||||
CLOSE = df["close"]
|
||||
d = df["trade_date"]
|
||||
MACD = (EMA(CLOSE, 30) - REF(EMA(CLOSE, 30), 1)) / REF(EMA(CLOSE, 30), 1) * 100
|
||||
DIF = EMA(SUM(MACD, 2), 5)
|
||||
buy_1 = DIF > REF(DIF, 1)
|
||||
buy_2 = DIF < REF(DIF, 1)
|
||||
|
||||
DEA = MA(DIF, 5)
|
||||
return pd.DataFrame(
|
||||
{'code': df['ts_code'], 'date': d, 'MACD': MACD, 'DIF': DIF, 'DEA': DEA, 'buy1': buy_1, "buy2": buy_2})
|
||||
|
||||
|
||||
def DJCPX(df_data: pd.DataFrame):
|
||||
"""
|
||||
顶级操盘线 指标
|
||||
:param df_data:
|
||||
:return:
|
||||
"""
|
||||
df = df_data[::-1].reset_index(drop=True)
|
||||
k = df["close"]
|
||||
d = df["trade_date"]
|
||||
# print(f'{d[i]} -->> {buy_1[i]} -->> {buy_2[i]} -->> {B[i]}')
|
||||
VAR_200 = round((100 - ((90 * (HHV(df["high"], 20) - df["close"])) / (
|
||||
HHV(df["high"], 20) - LLV(df["low"], 20)))), 2)
|
||||
VAR_300 = round((100 - MA(
|
||||
((100 * (HHV(df["high"], 5) - df["close"])) / (HHV(df["high"], 5) - LLV(df["low"], 5))),
|
||||
34)), 2)
|
||||
VAR_300_MA_5 = MA(VAR_300, 5)
|
||||
# F:IF(CROSS(VAR200,MA(VAR300,5)),LOW * 0.98,DRAWNULL),CROSSDOT,LINETHICK3,COLORFF00FF;
|
||||
# CROSS 上穿函数 CROSS(A,B)表示当A从下方向上穿过B时返回1,否则返回0
|
||||
F = np.zeros(df.shape[0])
|
||||
VAR_CROSS = CROSS(VAR_200, VAR_300_MA_5)
|
||||
for i in range(df.shape[0]):
|
||||
if VAR_CROSS[i] == 1:
|
||||
F[i] = round(df["low"][i] * 0.98, 2)
|
||||
# 重心:=(C+0.618*REF(C,1)+0.382*REF(C,1)+0.236*REF(C,3)+0.146*REF(C,4))/2.382;
|
||||
ZX = round((k + (0.618 * REF(k, 1)) + (0.382 * REF(k, 1)) + (0.236 * REF(k, 3)) + (
|
||||
0.146 * REF(k, 4))) / 2.382, 2)
|
||||
# 【操盘线】:EMA(((SLOPE(C,22)*20)+C),55),COLORYELLOW,LINETHICK4;
|
||||
CPX = round(EMA(((SLOPE(k, 22) * 20) + k), 55), 2)
|
||||
# 【黄金线】:IF(重心>=【操盘线】,【操盘线】,DRAWNULL),COLORRED,LINETHICK2;
|
||||
HJX = np.zeros(df.shape[0])
|
||||
# 【空仓线】:IF(重心<【操盘线】,【操盘线】,DRAWNULL),COLORCYAN,LINETHICK2;
|
||||
KCX = np.zeros(df.shape[0])
|
||||
for i in range(df.shape[0]):
|
||||
if ZX[i] >= CPX[i]:
|
||||
HJX[i] = CPX[i]
|
||||
else:
|
||||
KCX[i] = CPX[i]
|
||||
return pd.DataFrame(
|
||||
{'code': df['ts_code'], 'date': d, 'F': F, '黄金线': HJX, '空仓线': KCX})
|
||||
|
||||
|
||||
def CCI(DF, n: int = 14):
    # typical price
    TP = (DF['low'] + DF['high'] + DF['close']) / 3
    # rolling mean and mean absolute deviation of the typical price
    # (lower-case names avoid shadowing the module-level MA helper)
    tp_ma = TP.rolling(window=n).mean()
    tp_md = TP.rolling(window=n).apply(lambda x: abs(x - x.mean()).mean(), raw=False)
    return round((TP - tp_ma) / (0.015 * tp_md), 2)

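
# A minimal usage sketch for CCI (synthetic OHLC data, not part of the original module).
def _demo_cci():
    df = pd.DataFrame({
        "high": np.linspace(10.5, 12.5, 30),
        "low": np.linspace(9.5, 11.5, 30),
        "close": np.linspace(10.0, 12.0, 30),
    })
    print(CCI(df).tail())
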
def bullish(DF, N):
    """
    Bullish flag: 1 where the last N values form a monotonically increasing sequence
    :param DF:
    :param N:
    :return:
    """
    return pd.Series.rolling(DF, N).apply(lambda x: x.is_monotonic_increasing)


def bearish(DF, N):
    """
    Bearish flag: 1 where the last N values form a monotonically decreasing sequence
    :param DF:
    :param N:
    :return:
    """
    return pd.Series.rolling(DF, N).apply(lambda x: x.is_monotonic_decreasing)

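
# A small sanity check for bullish / bearish (synthetic data, not part of the original
# module): the rolling flags flip once the direction of the series changes.
def _demo_bullish_bearish():
    s = pd.Series([1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0])
    print(bullish(s, 3))  # 1.0 while the last three values keep rising
    print(bearish(s, 3))  # 1.0 while the last three values keep falling
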
def OBV(c, v, M):
    # np.diff: difference between consecutive closes
    change = np.diff(c)
    # np.sign: +1 for a rise, -1 for a fall, 0 for no change;
    # np.hstack prepends a 1 so the result keeps the same length as the input
    sig = np.hstack([[1], np.sign(change)])
    # cumulative sum of the signed volume gives the on-balance-volume line
    obv_ = np.cumsum(v * sig)
    OBV = pd.Series(obv_)
    MAOBV = MA(OBV, M)
    return pd.DataFrame({'OBV': OBV, 'MAOBV': MAOBV})

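
# A hedged usage sketch for OBV (synthetic arrays; assumes the MA helper defined earlier
# in this module).
def _demo_obv():
    close = np.array([10.0, 10.5, 10.2, 10.8, 11.0])
    vol = np.array([100.0, 120.0, 90.0, 150.0, 130.0])
    print(OBV(close, vol, 3))
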
def OBV_PLUS(DF, M):
    """
    OBV strategy upgrade: TODO not finished yet
    1. Give more weight to the volume of days with a large price move, to emphasise up- and down-trends.
    2. Add only a fraction of the day's volume to OBV instead of the whole day's volume.
    :param DF:
    :param M:
    :return:
    """
    CLOSE = DF['close']
    VOL = DF['vol']
    ref = REF(CLOSE, 1)
    var_total = 0
    for index, row in DF.iterrows():
        if np.isnan(ref[index]):
            var_total += VOL[index]
            continue
        if row["close"] > ref[index]:
            vol = VOL[index] * 1
        elif row["close"] == ref[index]:
            vol = 0
        else:
            vol = VOL[index] * -1
        var_total += vol

def ASI(OPEN, CLOSE, HIGH, LOW, M1=26, M2=10):
    """
    Accumulation Swing Index (振动升降指标)
    :param OPEN:
    :param CLOSE:
    :param HIGH:
    :param LOW:
    :param M1:
    :param M2:
    :return:
    """
    LC = REF(CLOSE, 1)
    AA = ABS(HIGH - LC)
    BB = ABS(LOW - LC)
    CC = ABS(HIGH - REF(LOW, 1))
    DD = ABS(LC - REF(OPEN, 1))
    R = IF((AA > BB) & (AA > CC), AA + BB / 2 + DD / 4, IF((BB > CC) & (BB > AA), BB + AA / 2 + DD / 4, CC + DD / 4))
    X = (CLOSE - LC + (CLOSE - OPEN) / 2 + LC - REF(OPEN, 1))
    SI = 16 * X / R * MAX(AA, BB)
    ASI = SUM(SI, M1)
    ASIT = MA(ASI, M2)
    return {'ASI': ASI, 'ASIT': ASIT}

def DMI(CLOSE, HIGH, LOW, M1=14, M2=6):  # Directional Movement Index; results match 同花顺 and 通达信 exactly
    TR = SUM(MAX(MAX(HIGH - LOW, ABS(HIGH - REF(CLOSE, 1))), ABS(LOW - REF(CLOSE, 1))), M1)
    HD = HIGH - REF(HIGH, 1)
    LD = REF(LOW, 1) - LOW
    DMP = SUM(IF((HD > 0) & (HD > LD), HD, 0), M1)
    DMM = SUM(IF((LD > 0) & (LD > HD), LD, 0), M1)
    PDI = (DMP * 100) / TR
    MDI = (DMM * 100) / TR
    ADX = MA(ABS(MDI - PDI) / (PDI + MDI) * 100, M2)
    ADXR = (ADX + REF(ADX, M2)) / 2
    return {'PDI': round(PDI.fillna(0), 2), 'MDI': round(MDI.fillna(0), 2), 'ADX': round(ADX.fillna(0), 2),
            'ADXR': round(ADXR.fillna(0), 2)}

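
# A hedged usage sketch for DMI (synthetic OHLC Series; assumes the SUM / MAX / ABS /
# REF / IF / MA helpers defined earlier in this module).
def _demo_dmi():
    close = pd.Series(np.linspace(10.0, 12.0, 40))
    high = close + 0.2
    low = close - 0.2
    print(DMI(close, high, low))
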
def RM_KDJ(C, H, L, N, M1, M2):
    # raw stochastic value, smoothed twice with RM_SMA to give K and D
    RSV = (C - LLV(L, N)) / (HHV(H, N) - LLV(L, N)) * 100
    K = RM_SMA(RSV, M1, 1)
    D = RM_SMA(K, M2, 1)
    J = 3 * K - 2 * D
    return pd.DataFrame({'KDJ_K': round(K, 2), 'KDJ_D': round(D, 2), 'KDJ_J': round(J, 2)})


def INTPART(number):
    # integer part of a numeric Series (NaN treated as 0)
    number = number.fillna(0)
    return number.astype(int)

def JXNH(CLOSE, OPEN, VOL):
    VAR1 = MA(CLOSE, 5)
    VAR2 = MA(CLOSE, 10)
    VAR3 = MA(CLOSE, 30)
    VARB = SUM(CLOSE * VOL * 100, 28) / SUM(VOL * 100, 28)
    VARC = INTPART(VARB * 100) / 100
    VARD = EMA(CLOSE, 5) - EMA(CLOSE, 10)
    VARE = EMA(VARD, 9)
    VAR13 = REF(VARE, 1)
    VAR14 = VARE
    VAR15 = VAR14 - VAR13
    VAR16 = REF(VARD, 1)
    VAR17 = VARD
    VAR18 = VAR17 - VAR16
    VAR19 = OPEN
    VAR1A = CLOSE
    JXNH = (VAR19 <= VAR1) & \
           (VAR19 <= VAR2) & \
           (VAR19 <= VAR3) & \
           (VAR1A >= VAR1) & \
           (VAR1A >= VARC) & (VAR15 > 0) & (VAR18 > 0)
    return pd.DataFrame({'JXNH': JXNH})
@ -0,0 +1,176 @@
import struct
from pathlib import Path

import pandas as pd

"""
Utility class for working with TDX (通达信) market data files
"""

class TdxUtil:
    SECURITY_EXCHANGE = ["sz", "sh"]
    SECURITY_COEFFICIENT = {"SH_A_STOCK": [0.01, 0.01], "SH_B_STOCK": [0.001, 0.01], "SH_INDEX": [0.01, 1.0],
                            "SH_FUND": [0.001, 1.0], "SH_BOND": [0.001, 1.0], "SZ_A_STOCK": [0.01, 0.01],
                            "SZ_B_STOCK": [0.01, 0.01], "SZ_INDEX": [0.01, 1.0], "SZ_FUND": [0.001, 0.01],
                            "SZ_BOND": [0.001, 0.01]}
    SECURITY_TYPE = ["SH_A_STOCK", "SH_B_STOCK", "SH_INDEX", "SH_FUND", "SH_BOND", "SZ_A_STOCK", "SZ_B_STOCK",
                     "SZ_INDEX", "SZ_FUND", "SZ_BOND"]

    def __init__(self, tdx_path: str):
        self.f_name = None
        self.path = tdx_path
        # path of the TDX watchlist (自选股) file
        self.zxg_path = tdx_path + "\\T0002\\blocknew\\ZXG.blk"
        # Shenzhen daily bars
        self.lday_sz_path = tdx_path + "\\vipdoc\\sz\\lday\\"
        # Shanghai daily bars
        self.lday_sh_path = tdx_path + "\\vipdoc\\sh\\lday\\"

    def set_zxg_file(self, cont: pd.Series):
        """
        Write the TDX watchlist (自选股) file
        :param cont: iterable of stock codes such as "600000.SH"
        :return:
        """
        # write one entry per line; a .blk entry is the 6-digit code prefixed with
        # the market flag (1 = Shanghai, 0 = Shenzhen)
        with open(self.zxg_path, "w") as file:
            for item in cont:
                stock_code = str(item)
                code = stock_code.split(".")[0]
                if len(code) > 1:
                    if code.startswith(('50', '51', '60', '688', '73', '90', '110', '113', '132', '204', '78')):
                        code = '1' + code
                    elif code.startswith(('00', '13', '18', '15', '16', '18', '20', '30', '39', '115', '1318')):
                        code = '0' + code
                file.write(code + '\n')
        # the "with" block closes the file automatically

    def get_zxg_file(self):
        """
        Read the TDX watchlist (自选股) file
        :return:
        """
        with open(self.zxg_path, 'r') as f:
            z = f.read()
        return z

    def get_tdx_stock_data(self, code=None, freq=None, market=None):
        """
        Fetch K-line data (daily, 1-minute, 5-minute or 30-minute bars)
        :param code: stock code
        :param freq: minute frequency, 1 for 1-minute bars, 5 for 5-minute bars
        :param market: sh for the Shanghai Stock Exchange, sz for the Shenzhen Stock Exchange, bj for the Beijing Stock Exchange
        :return:
        """
        # with no code given, return every available stock
        if code is None:
            return self.get_all_stock()
        if freq is None:
            return self.get_day_k(code, market)
        else:
            return self.get_freq_k(code, market, freq)

    def get_all_stock(self):
        """
        List every stock code found in the local TDX data directories (A-shares only)
        :return:
        """
        codes = []
        for exchange in self.SECURITY_EXCHANGE:
            files = Path(f'{self.path}\\vipdoc\\{exchange}\\lday\\')
            for file in files.iterdir():
                if not file.is_file():
                    continue
                # file names look like sz000001.day: a two-letter market prefix plus the code
                tdx_code = file.name.split(".")[0]
                market = tdx_code[:2]
                code = tdx_code[2:]
                res = self.get_security_type(code=code, market=market)
                # keep A-shares only
                filter_stock = ["SZ_A_STOCK", "SH_A_STOCK"]
                if res not in filter_stock:
                    continue
                codes.append(code)
        return codes

    def get_day_k(self, code, market):
        security_type = self.get_security_type(code, market=market)
        if security_type not in self.SECURITY_TYPE:
            print("Unknown security type !\n")
            raise NotImplementedError
        if security_type == "SZ_A_STOCK":
            self.f_name = f"{self.lday_sz_path}sz{code}.day"
        if security_type == "SH_A_STOCK":
            self.f_name = f"{self.lday_sh_path}sh{code}.day"
        print(self.f_name)
        df = pd.DataFrame(columns=['trade_date', 'open', 'high', 'low', 'close', 'amount', 'volume'])
        file_path = Path(self.f_name)
        # price / volume scaling factors for A-shares (see SECURITY_COEFFICIENT)
        coefficient = [0.01, 0.01]
        if not file_path.is_file():
            raise FileNotFoundError(f"{self.f_name} is not a file!")
        content = file_path.read_bytes()
        # each .day record is 32 bytes: date, open, high, low, close (uint32),
        # amount (float32), volume (uint32) and a reserved field
        record_struct = struct.Struct('<IIIIIfII')
        print(len(content), record_struct.size)
        for offset in range(0, len(content), record_struct.size):
            row = record_struct.unpack_from(content, offset)
            t_date = str(row[0])
            # datestr = t_date[:4] + "-" + t_date[4:6] + "-" + t_date[6:]
            close = row[4] * coefficient[0]
            open_price = row[1] * coefficient[0]
            new_row = [
                t_date,
                open_price,
                row[2] * coefficient[0],
                row[3] * coefficient[0],
                close,
                row[5],
                row[6] * coefficient[1],
            ]
            df.loc[len(df)] = new_row
        return df

    def get_freq_k(self, code, market, freq):
        # TODO: minute-bar parsing is not implemented yet
        pass

    def get_security_type(self, code, name=None, market=None):
        exchange = str(market).lower()
        code_head = code[:2]
        if name is not None:
            # skip ST / *ST names
            if 'ST' in name or '*' in name:
                return None
        if exchange == self.SECURITY_EXCHANGE[0]:
            if code_head in ["00", "30"]:
                return "SZ_A_STOCK"
            elif code_head in ["20"]:
                return "SZ_B_STOCK"
            elif code_head in ["39"]:
                return "SZ_INDEX"
            elif code_head in ["15", "16"]:
                return "SZ_FUND"
            elif code_head in ["10", "11", "12", "13", "14"]:
                return "SZ_BOND"
        elif exchange == self.SECURITY_EXCHANGE[1]:
            if code_head in ["60", "68"]:  # 688xxx: STAR Market
                return "SH_A_STOCK"
            elif code_head in ["90"]:
                return "SH_B_STOCK"
            elif code_head in ["00", "88", "99"]:
                return "SH_INDEX"
            elif code_head in ["50", "51"]:
                return "SH_FUND"
            elif code_head in ["01", "10", "11", "12", "13", "14", "20"]:
                return "SH_BOND"
        else:
            # with no exchange flag, only classify A-shares  TODO: exclude the ChiNext board
            if code_head in ["00"]:
                return "SZ_A_STOCK"
            elif code_head in ["60"]:
                return "SH_A_STOCK"
            else:
                return None

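
# A small round-trip sketch of one TDX .day record (synthetic values, not read from a
# real file): prices are stored as integers scaled by 100, matching the 0.01 coefficient
# used in get_day_k above.
def _demo_day_record():
    record_struct = struct.Struct('<IIIIIfII')
    raw = record_struct.pack(20240105, 1012, 1030, 1005, 1020, 1.23e7, 123456, 0)
    date, open_, high, low, close, amount, volume, _ = record_struct.unpack(raw)
    print(date, open_ * 0.01, high * 0.01, low * 0.01, close * 0.01, amount, volume)
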
if __name__ == '__main__':
    tdx = TdxUtil("D:\\new_tdx")
    print(tdx.get_tdx_stock_data("002448"))
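    # a hypothetical example of writing the watchlist back to TDX (uncomment to try;
    # set_zxg_file expects an iterable of codes such as "600000.SH"):
    # tdx.set_zxg_file(pd.Series(["600000.SH", "000001.SZ"]))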