diff --git a/pilot/openapi/api_v1/feedback/feed_back_db.py b/pilot/openapi/api_v1/feedback/feed_back_db.py index 2b57c4bde..64b99b8de 100644 --- a/pilot/openapi/api_v1/feedback/feed_back_db.py +++ b/pilot/openapi/api_v1/feedback/feed_back_db.py @@ -1,13 +1,11 @@ from datetime import datetime from sqlalchemy import Column, Integer, Text, String, DateTime -from sqlalchemy.ext.declarative import declarative_base -from pilot.connections.rdbms.base_dao import BaseDao +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody -Base = declarative_base() - class ChatFeedBackEntity(Base): __tablename__ = "chat_feed_back" @@ -33,13 +31,15 @@ class ChatFeedBackEntity(Base): class ChatFeedBackDao(BaseDao): def __init__(self): - super().__init__(database="history", orm_base=Base, create_not_exist_table=True) + super().__init__( + database="dbgpt", orm_base=Base, db_engine=engine, session=session + ) def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): # Todo: We need to have user information first. def_user_name = "" - session = self.Session() + session = self.get_session() chat_feed_back = ChatFeedBackEntity( conv_uid=feed_back.conv_uid, conv_index=feed_back.conv_index, @@ -73,7 +73,7 @@ class ChatFeedBackDao(BaseDao): session.close() def get_chat_feed_back(self, conv_uid: str, conv_index: int): - session = self.Session() + session = self.get_session() result = ( session.query(ChatFeedBackEntity) .filter(ChatFeedBackEntity.conv_uid == conv_uid)