# Source: qnloft-stock/DB/db_main.py (272 lines, 9.6 KiB, Python)
# Last commit: 2023-11-24 02:49:22 +00:00
import json
import traceback
import pandas
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, scoped_session, aliased, declarative_base, clear_mappers
from sqlalchemy import *
from DB import db_config as config
class DbMain:
    """Thin SQLAlchemy / pandas helper around one engine, session and inspector.

    ``engine``, ``session`` and ``inspector`` start as ``None`` and must be
    assigned before the data-access methods are used (presumably by a subclass
    or the caller -- not shown in this file).

    Error policy: data-access methods are best-effort. Exceptions are printed
    via ``_log_error`` and swallowed, so a failure shows up as a ``None`` return
    (or as nothing happening) rather than as a raised exception. Each
    data-access method closes ``self.session`` when it finishes; SQLAlchemy's
    ``Session.close()`` discards any transaction that was not committed.
    """

    def __init__(self):
        self.inspector = None
        self.session = None
        self.engine = None

    def get_session(self):
        """Return a new Session bound to ``self.engine``.

        The factory is wrapped in a ``scoped_session`` registry, but a fresh
        registry is built on every call. Two calls therefore return two distinct
        sessions even on the same thread. Callers presumably call this once and
        keep the result.
        """
        factory = sessionmaker(bind=self.engine)
        registry = scoped_session(factory)
        return registry()

    # ===================== internal helpers =============================
    @staticmethod
    def _log_error(error):
        """Print one line per traceback frame of ``error``, then the full traceback once."""
        for filename, lineno, funcname, source in traceback.extract_tb(error.__traceback__):
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}")
        detail = "".join(traceback.format_exception(type(error), error, error.__traceback__))
        print(f"错误内容:{detail}")

    @staticmethod
    def _table_name(entries):
        """Return ``__tablename__`` of an object/model, or of the first item of a list.

        Returns ``None`` for an empty list, so callers can treat it as "no table".
        """
        if isinstance(entries, list):
            return entries[0].__tablename__ if entries else None
        return entries.__tablename__

    @staticmethod
    def _where_clause(query_conditions):
        """Build an AND-joined WHERE fragment plus its bind parameters.

        Column names are interpolated and must be trusted identifiers. Values
        travel as bind parameters ``w0``, ``w1``, ... so quotes or ``:`` inside
        them cannot break or inject into the SQL. ``None`` becomes ``IS NULL``
        because ``= NULL`` never matches.
        """
        clauses, params = [], {}
        for i, (column, value) in enumerate(query_conditions.items()):
            if value is None:
                clauses.append(f"{column} IS NULL")
            else:
                name = f"w{i}"
                clauses.append(f"{column} = :{name}")
                params[name] = value
        return " AND ".join(clauses), params

    @staticmethod
    def _set_clause(values):
        """Build a comma-joined SET fragment plus bind parameters ``s0``, ``s1``, ...

        A ``None`` value binds as SQL NULL.
        """
        clauses, params = [], {}
        for i, (column, value) in enumerate(values.items()):
            name = f"s{i}"
            clauses.append(f"{column} = :{name}")
            params[name] = value
        return ", ".join(clauses), params

    # ===================== insert or update methods =============================
    def insert_all_entry(self, entries):
        """Create the table if needed, then add and commit every ORM object in ``entries``."""
        try:
            self.create_table(entries)
            self.session.add_all(entries)
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def insert_entry(self, entry):
        """Create the table if needed, then add and commit a single ORM object."""
        try:
            self.create_table(entry)
            self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def insert_or_update(self, entry, query_conditions):
        """Update the rows matching ``query_conditions``, or insert ``entry`` if none match.

        ``entry`` must implement ``to_dict()``. It returns ``{column: value}`` for
        the columns to overwrite when a matching row already exists.

        :param entry: ORM object; its ``__tablename__`` names the target table.
        :param query_conditions: ``{column: value}`` pairs AND-ed together to
            find existing rows. Must be non-empty; a ``None`` value matches NULL.
        :return: None. Errors are logged, not raised.
        """
        try:
            self.create_table(entry)
            if not query_conditions:
                raise ValueError("query_conditions must not be empty")
            table = entry.__tablename__
            where_sql, where_params = self._where_clause(query_conditions)
            count = self.session.execute(
                text(f"select count(1) from {table} where 1=1 and {where_sql}"),
                where_params,
            ).scalar()
            if count > 0:
                # A matching row exists: overwrite the columns reported by to_dict().
                if not hasattr(entry, 'to_dict'):
                    raise Exception("对象 不包含 to_dict 方法,请添加!")
                set_sql, set_params = self._set_clause(entry.to_dict())
                if set_sql:  # an empty to_dict() means there is nothing to change
                    self.session.execute(
                        text(f"UPDATE {table} SET {set_sql} WHERE {where_sql}"),
                        {**where_params, **set_params},
                    )
            else:
                self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.expire_all()
            self.session.close()

    def create_table(self, entries):
        """Create the table behind ``entries`` (object, model or list) if it does not exist.

        ``metadata.create_all`` creates every table registered on that metadata
        that is still missing. An empty list is a no-op.
        """
        table_name = self._table_name(entries)
        if table_name is None:
            return
        if not self.inspector.has_table(table_name):
            sample = entries[0] if isinstance(entries, list) else entries
            sample.metadata.create_all(self.engine)

    def pandas_insert(self, data=None, table_name=""):
        """Append a pandas DataFrame to ``table_name``, writing its index as column ``id``.

        :param data: DataFrame to write. Omitting it logs an error.
        :param table_name: target table name.
        """
        try:
            if data is None:
                raise ValueError("data must be a pandas DataFrame")
            # engine.begin() commits on success and always returns the
            # connection to the pool; a bare engine.connect() was never closed.
            with self.engine.begin() as conn:
                data.to_sql(table_name, con=conn, if_exists='append', index=True, index_label='id')
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    # ===================== select methods =============================
    def query_by_id(self, model, db_id):
        """Return all ``model`` rows whose ``id`` equals ``db_id`` (None on error)."""
        try:
            return self.session.query(model).filter_by(id=db_id).all()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
        """Read ``model``'s table into a DataFrame, optionally ordered and paginated.

        :param model: mapped class to query.
        :param order_col: optional ORDER BY column/expression.
        :param page_number: 1-based page index; used only together with ``page_size``.
        :param page_size: rows per page.
        :return: DataFrame; an empty DataFrame if the table is missing; None on error.
        """
        try:
            if not self.has_table(model):
                return pandas.DataFrame()
            query = self.session.query(model)
            if order_col is not None:
                query = query.order_by(order_col)
            if page_number is not None and page_size is not None:
                query = query.offset((page_number - 1) * page_size).limit(page_size)
            with self.engine.connect() as conn:
                return pandas.read_sql_query(query.statement, conn)
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_sql(self, stmt=""):
        """Run ``stmt`` (SQL string or SQLAlchemy selectable) through pandas; None on error."""
        try:
            with self.engine.connect() as conn:
                return pandas.read_sql_query(sql=stmt, con=conn)
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_condition(self, model, query_condition):
        """Return ``model`` rows matching ``query_condition``, ordered by id, as a DataFrame.

        Combine several conditions with ``and_``, e.g.
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        The result goes through ``reset_index()``. Returns None on error.
        """
        try:
            query = self.session.query(model).filter(query_condition).order_by(model.id)
            frame = self.pandas_query_by_sql(stmt=query.statement)
            # pandas_query_by_sql has already logged any failure and returned None.
            return None if frame is None else frame.reset_index()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    # ===================== delete methods =============================
    def delete_by_id(self, model, db_id):
        """Delete the ``model`` rows whose ``id`` equals ``db_id`` and commit."""
        try:
            self.session.query(model).filter_by(id=db_id).delete()
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def delete_by_condition(self, model, delete_condition):
        """Delete the ``model`` rows matching ``delete_condition`` and commit.

        Example: ``StockDaily.trade_date == '20230823'``. For several conditions,
        combine them with ``and_(...)``.
        """
        try:
            self.session.query(model).filter(delete_condition).delete()
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def delete_all_table(self, model):
        """Delete every row of ``model``'s table (the table itself is kept) and commit."""
        try:
            self.session.query(model).delete()
            self.session.commit()
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    # ===================== other methods =============================
    def has_table(self, entries):
        """Return True if the table behind ``entries`` exists; False for an empty list."""
        table_name = self._table_name(entries)
        return table_name is not None and self.inspector.has_table(table_name)

    def execute_sql(self, s):
        """Execute raw SQL ``s`` and return the SQLAlchemy result (None on error).

        Statements that return no rows (INSERT/UPDATE/DELETE/DDL) are committed.
        Otherwise the ``close()`` in ``finally`` would silently roll them back.
        """
        try:
            result = self.session.execute(text(s))
            if not result.returns_rows:
                self.session.commit()
            return result
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def execute_sql_to_pandas(self, s):
        """Execute raw SQL ``s`` and return all rows as a DataFrame (None on error)."""
        try:
            res = self.session.execute(text(s))
            return pandas.DataFrame(res.fetchall(), columns=res.keys())
        except Exception as e:
            self._log_error(e)
        finally:
            self.session.close()

    def close(self):
        """Close the current session."""
        self.session.close()