# qnloft-stock/DB/db_main.py
# (272 lines, 9.6 KiB, Python — header recovered from a code-hosting page.
#  The host flagged this file as containing ambiguous Unicode characters that
#  might be confused with other characters.)

import json
import traceback
import pandas
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker, scoped_session, aliased, declarative_base, clear_mappers
from sqlalchemy import *
from DB import db_config as config
class DbMain:
    """Convenience wrapper around a SQLAlchemy engine/session plus pandas I/O.

    ``engine``, ``session`` and ``inspector`` are initialised to ``None`` here;
    presumably a subclass or caller assigns them before any other method is
    used (not visible in this file — confirm). Methods print diagnostics
    instead of raising, and close ``self.session`` in a ``finally`` block.
    """

    def __init__(self):
        self.inspector = None
        self.session = None
        self.engine = None

    def get_session(self):
        """Return a new session bound to ``self.engine`` via ``scoped_session``."""
        session_factory = sessionmaker(bind=self.engine)
        # scoped_session keeps one session per thread (threading.local).
        # NOTE(review): a fresh registry is created on every call, so the
        # returned sessions are not shared between calls.
        Session = scoped_session(session_factory)
        return Session()

    # ===================== internal helpers =============================
    @staticmethod
    def _print_error(e):
        """Print one location line per traceback frame of ``e``, then the full traceback once."""
        for frame in traceback.extract_tb(e.__traceback__):
            print(f"在文件 {frame.filename} 的第 {frame.lineno} 行发生错误 ,方法名称:{frame.name} "
                  f"发生错误的源码: {frame.line}")
        print(f"错误内容:{''.join(traceback.format_exception(type(e), e, e.__traceback__))}")

    @staticmethod
    def _first_entry(entries):
        """Return the first element of a list/tuple of entities, or ``entries`` itself."""
        return entries[0] if isinstance(entries, (list, tuple)) else entries

    @staticmethod
    def _equality_clause(columns_values, separator, prefix):
        """Build ``col = :<prefix>_<i>`` fragments joined by ``separator`` plus their bind params.

        Column names are interpolated verbatim (they come from code/models,
        not from data). Values are always bound parameters, so quotes or SQL
        fragments inside values cannot change the statement. ``None`` values
        are bound as SQL NULL.

        :param columns_values: mapping of column name -> value
        :param separator: string placed between fragments (e.g. " AND " or ", ")
        :param prefix: bind-parameter name prefix; keeps names unique per clause
        :return: (clause_text, params_dict)
        """
        fragments = []
        params = {}
        for index, (column, value) in enumerate(columns_values.items()):
            bind_name = f"{prefix}_{index}"
            fragments.append(f"{column} = :{bind_name}")
            params[bind_name] = value
        return separator.join(fragments), params

    # ===================== insert / update =============================
    def insert_all_entry(self, entries):
        """Create the table if needed, then add and commit all ``entries``.

        An empty ``entries`` is a no-op (the session is still closed).
        """
        try:
            if not entries:
                # create_table would otherwise fail indexing entries[0].
                return
            self.create_table(entries)
            self.session.add_all(entries)
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def insert_entry(self, entry):
        """Create the table if needed, then add and commit a single ``entry``."""
        try:
            self.create_table(entry)
            self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def insert_or_update(self, entry, query_conditions):
        """Update the rows matching ``query_conditions``, or insert ``entry`` if none match.

        The update path requires ``entry`` to define a ``to_dict`` method that
        returns the columns to update as a dict.

        :param entry: mapped entity instance (must have ``__tablename__``)
        :param query_conditions: non-empty dict of column name -> value, AND-ed together
        :return: None; errors are printed, not raised
        """
        try:
            self.create_table(entry)
            if not query_conditions:
                # An empty filter would either be invalid SQL or match every row.
                raise ValueError("query_conditions 不能为空")
            table_name = entry.__tablename__
            where_clause, where_params = self._equality_clause(query_conditions, " AND ", "w")
            count_sql = text(f"select count(1) from {table_name} where {where_clause}")
            if self.session.execute(count_sql, where_params).scalar() > 0:
                if not hasattr(entry, 'to_dict'):
                    raise Exception("对象 不包含 to_dict 方法,请添加!")
                set_clause, set_params = self._equality_clause(entry.to_dict(), ", ", "s")
                update_sql = text(f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}")
                # "w_" and "s_" prefixes keep the two parameter sets from colliding.
                self.session.execute(update_sql, {**where_params, **set_params})
            else:
                # insert_entry commits and closes the session itself.
                self.insert_entry(entry)
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.expire_all()
            self.session.close()

    def create_table(self, entries):
        """Create the entity's table (via its metadata) if it does not exist yet.

        :param entries: a mapped entity/model, or a list/tuple whose first element is used
        """
        entry = self._first_entry(entries)
        if not self.inspector.has_table(entry.__tablename__):
            # create_all creates every table registered on this metadata.
            entry.metadata.create_all(self.engine)

    def pandas_insert(self, data=pandas, table_name=""):
        """Append a DataFrame to ``table_name``, writing its index as an ``id`` column.

        NOTE(review): the default for ``data`` is the pandas module itself, not a
        DataFrame — callers must pass a DataFrame.

        :param data: DataFrame to write
        :param table_name: destination table
        """
        try:
            # Passing the Engine lets pandas open, commit and close its own
            # connection; engine.connect() here used to leak a connection.
            data.to_sql(table_name, con=self.engine, if_exists='append', index=True, index_label='id')
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    # ===================== select =============================
    def query_by_id(self, model, db_id):
        """Return all ``model`` rows whose ``id`` equals ``db_id`` (a list), or None on error."""
        try:
            return self.session.query(model).filter_by(id=db_id).all()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
        """Read ``model``'s table into a DataFrame, optionally ordered and paginated.

        :param model: mapped model class
        :param order_col: optional column/expression to order by
        :param page_number: 1-based page index (used only together with page_size)
        :param page_size: rows per page
        :return: DataFrame (empty if the table does not exist), or None on error
        """
        try:
            if self.has_table(model):
                query = self.session.query(model)
                if order_col is not None:
                    query = query.order_by(order_col)
                if page_number is not None and page_size is not None:
                    offset = (page_number - 1) * page_size
                    query = query.offset(offset).limit(page_size)
                return pandas.read_sql_query(query.statement, self.engine)
            return pandas.DataFrame()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_sql(self, stmt=""):
        """Run ``stmt`` (SQL string or SQLAlchemy selectable) through pandas; None on error."""
        try:
            return pandas.read_sql_query(sql=stmt, con=self.engine)
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_condition(self, model, query_condition):
        """Return ``model`` rows matching ``query_condition``, ordered by id, as a DataFrame.

        Example condition: ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        Returns None on error.
        """
        try:
            query = self.session.query(model).filter(query_condition).order_by(model.id)
            df = self.pandas_query_by_sql(stmt=query.statement)
            # pandas_query_by_sql already printed its own error and returned None.
            return None if df is None else df.reset_index()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    # ===================== delete =============================
    def delete_by_id(self, model, db_id):
        """Delete ``model`` rows whose ``id`` equals ``db_id`` and commit."""
        try:
            self.session.query(model).filter_by(id=db_id).delete()
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def delete_by_condition(self, model, delete_condition):
        """Delete ``model`` rows matching ``delete_condition`` and commit.

        Example: ``StockDaily.trade_date == '20230823'`` or
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        """
        try:
            self.session.query(model).filter(delete_condition).delete()
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def delete_all_table(self, model):
        """Delete every row of ``model``'s table (the table itself is kept) and commit."""
        try:
            self.session.query(model).delete()
            self.session.commit()
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    # ===================== misc =============================
    def has_table(self, entries):
        """Return whether the entity's table exists (does not create it).

        :param entries: a mapped entity/model, or a list/tuple whose first element is used
        """
        return self.inspector.has_table(self._first_entry(entries).__tablename__)

    def execute_sql(self, s):
        """Execute the raw SQL string ``s`` through the session and return the result.

        Statements that return no rows (INSERT/UPDATE/DELETE/DDL) are committed;
        without this, the ``session.close()`` below rolled them back. The session
        is closed before returning, so a row-returning result may no longer be
        fetchable — NOTE(review): prefer ``execute_sql_to_pandas`` for queries.
        Returns None on error.
        """
        try:
            result = self.session.execute(text(s))
            if not result.returns_rows:
                self.session.commit()
            return result
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def execute_sql_to_pandas(self, s):
        """Execute the raw SQL string ``s`` and return its rows as a DataFrame; None on error."""
        try:
            res = self.session.execute(text(s))
            return pandas.DataFrame(res.fetchall(), columns=res.keys())
        except Exception as e:
            self._print_error(e)
        finally:
            self.session.close()

    def close(self):
        """Close the current session."""
        self.session.close()