272 lines
9.6 KiB
Python
272 lines
9.6 KiB
Python
|
import json
|
|||
|
import traceback
|
|||
|
|
|||
|
import pandas
|
|||
|
from sqlalchemy.exc import SQLAlchemyError
|
|||
|
from sqlalchemy.orm import sessionmaker, scoped_session, aliased, declarative_base, clear_mappers
|
|||
|
from sqlalchemy import *
|
|||
|
|
|||
|
from DB import db_config as config
|
|||
|
|
|||
|
|
|||
|
class DbMain:
    """CRUD helper around a SQLAlchemy engine, ORM session and inspector.

    ``__init__`` only sets ``engine``, ``session`` and ``inspector`` to None;
    presumably a subclass or caller assigns them before use — TODO confirm.

    Every public data method catches exceptions, prints diagnostics and returns
    ``None`` instead of raising, and closes ``self.session`` in ``finally``.
    """

    def __init__(self):
        # presumably a sqlalchemy Inspector for self.engine; used by has_table()/create_table()
        self.inspector = None
        # ORM Session used by every CRUD method below
        self.session = None
        # SQLAlchemy Engine; bound by get_session() and used by the pandas helpers
        self.engine = None

    def get_session(self):
        """Return a new Session bound to ``self.engine``.

        The session comes from a ``scoped_session`` registry, which isolates
        sessions per thread via ``threading.local``. Note that a fresh registry
        is built on every call, so repeated calls do not share a session.
        """
        session_factory = sessionmaker(bind=self.engine)
        registry = scoped_session(session_factory)
        return registry()

    # ===================== internal helpers =============================

    @staticmethod
    def _report_error(exc):
        """Print every traceback frame of *exc*, followed by the full traceback once."""
        for filename, lineno, funcname, source in traceback.extract_tb(exc.__traceback__):
            print(f"在文件 {filename} 的第 {lineno} 行发生错误 ,方法名称:{funcname} 发生错误的源码: {source}")
        # format_exception works from the exception object itself, so this does not
        # depend on being called while the exception is still being handled.
        details = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        print(f"错误内容:{details}")

    @staticmethod
    def _table_name(entries):
        """Return ``__tablename__`` of a model/instance, or of the first element of a list.

        Returns None for an empty list.
        """
        if isinstance(entries, list):
            return entries[0].__tablename__ if entries else None
        return entries.__tablename__

    # ===================== insert or update methods =============================

    def insert_all_entry(self, entries):
        """Create the table if missing, add all *entries* and commit.

        Errors are printed, not raised; the session is always closed.
        """
        try:
            self.create_table(entries)
            self.session.add_all(entries)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def insert_entry(self, entry):
        """Create the table if missing, add *entry* and commit.

        Errors are printed, not raised; the session is always closed.
        """
        try:
            self.create_table(entry)
            self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def insert_or_update(self, entry, query_conditions):
        """Update rows matching *query_conditions* from ``entry.to_dict()``, else insert *entry*.

        :param entry: mapped ORM instance (needs ``__table__``). It must define
            ``to_dict()``, returning ``{column_name: value}`` for the columns to
            overwrite when a matching row exists.
        :param query_conditions: non-empty dict ``{column_name: value}``. The
            pairs are combined with AND; a None value matches NULL.
        :return: None. Errors are printed, not raised.
        """
        try:
            if not query_conditions:
                # An empty condition would match, and then update, every row.
                raise ValueError("query_conditions must not be empty")
            self.create_table(entry)
            table = entry.__table__
            # Bound parameters instead of string-built SQL. Values containing quotes
            # no longer break the statement, None becomes NULL, and caller data
            # cannot inject SQL.
            where_clause = and_(*(table.c[key] == value for key, value in query_conditions.items()))
            count_stmt = select(func.count()).select_from(table).where(where_clause)

            if self.session.execute(count_stmt).scalar() > 0:
                # A matching row exists: update it.
                if not hasattr(entry, 'to_dict'):
                    raise Exception("对象 不包含 to_dict 方法,请添加!")
                self.session.execute(update(table).where(where_clause).values(**entry.to_dict()))
            else:
                # No match: insert. This runs in the same transaction as the count.
                self.session.add(entry)
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.expire_all()
            self.session.close()

    def create_table(self, entries):
        """Create the model's table if the inspector reports it missing.

        *entries* is a model/instance or a non-empty list of them; an empty list
        is a no-op. ``metadata.create_all`` creates every table registered on
        that metadata, not only this one.
        """
        if self._table_name(entries) is None:
            return
        if not self.has_table(entries):
            model = entries[0] if isinstance(entries, list) else entries
            model.metadata.create_all(self.engine)

    def pandas_insert(self, data=None, table_name=""):
        """Append a pandas DataFrame to *table_name*.

        The DataFrame index is written as column ``id``.

        :param data: the pandas DataFrame to insert.
        :param table_name: target table name.
        :return: None. Errors are printed, not raised.
        """
        try:
            if data is None:
                raise ValueError("data must be a pandas DataFrame")
            # Passing the Engine lets pandas open, commit and close its own
            # connection. The previous engine.connect() call was never closed.
            data.to_sql(table_name, con=self.engine, if_exists='append', index=True, index_label='id')
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    # ===================== select methods =============================

    def query_by_id(self, model, db_id):
        """Return the list of *model* rows whose ``id`` equals *db_id*, or None on error."""
        try:
            return self.session.query(model).filter_by(id=db_id).all()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def pandas_query_by_model(self, model, order_col=None, page_number=None, page_size=None):
        """Read *model*'s table into a DataFrame using pandas' SQL reader.

        :param model: mapped model class.
        :param order_col: optional column/expression to ORDER BY.
        :param page_number: 1-based page index; paging applies only when page_size is also given.
        :param page_size: rows per page.
        :return: the DataFrame, an empty DataFrame if the table does not exist, or None on error.
        """
        try:
            if self.has_table(model):
                query = self.session.query(model)
                if order_col is not None:
                    query = query.order_by(order_col)
                if page_number is not None and page_size is not None:
                    offset = (page_number - 1) * page_size
                    query = query.offset(offset).limit(page_size)
                # Engine instead of an unclosed engine.connect(); pandas manages the connection.
                return pandas.read_sql_query(query.statement, self.engine)
            return pandas.DataFrame()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_sql(self, stmt=""):
        """Run *stmt* (SQL string or SQLAlchemy selectable) with pandas.

        :return: the resulting DataFrame, or None on error.
        """
        try:
            return pandas.read_sql_query(sql=stmt, con=self.engine)
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()
            self.engine.dispose()

    def pandas_query_by_condition(self, model, query_condition):
        """Return *model* rows matching *query_condition*, ordered by ``id``, as a DataFrame.

        For multiple conditions, combine them with ``and_``, e.g.
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        ``reset_index()`` is applied to the result. Returns None on error.
        """
        try:
            query = self.session.query(model).filter(query_condition).order_by(model.id)
            return self.pandas_query_by_sql(stmt=query.statement).reset_index()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    # ===================== delete methods =============================

    def delete_by_id(self, model, db_id):
        """Delete *model* rows whose ``id`` equals *db_id* and commit. Errors are printed."""
        try:
            self.session.query(model).filter_by(id=db_id).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def delete_by_condition(self, model, delete_condition):
        """Delete *model* rows matching *delete_condition* and commit. Errors are printed.

        Example conditions:
        ``StockDaily.trade_date == '20230823'``, or for several conditions
        ``and_(StockDaily.trade_date == '20230823', StockDaily.symbol == 'ABC')``.
        """
        try:
            self.session.query(model).filter(delete_condition).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def delete_all_table(self, model):  # empty the table's data
        """Delete every row of *model*'s table and commit. Errors are printed."""
        try:
            self.session.query(model).delete()
            self.session.commit()
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    # ===================== other methods =============================

    def has_table(self, entries):
        """Return whether the inspector reports the model's table as existing.

        *entries* is a model/instance or a list of them; an empty list yields False.
        """
        table_name = self._table_name(entries)
        return table_name is not None and self.inspector.has_table(table_name)

    def execute_sql(self, s):
        """Execute raw SQL string *s*, commit, and return its result (None on error).

        Row-returning results are buffered with ``Result.freeze()`` before the
        session is closed, so the returned Result can still be iterated. Without
        the commit, DML would be rolled back by ``session.close()``.
        """
        try:
            sql_text = text(s)
            result = self.session.execute(sql_text)
            if result.returns_rows:
                result = result.freeze()()
            self.session.commit()
            return result
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def execute_sql_to_pandas(self, s):
        """Execute raw SQL *s* and return its rows as a DataFrame (None on error)."""
        try:
            sql_text = text(s)
            res = self.session.execute(sql_text)
            return pandas.DataFrame(res.fetchall(), columns=res.keys())
        except Exception as e:
            self._report_error(e)
        finally:
            self.session.close()

    def close(self):
        """Close the current session."""
        self.session.close()
|