feat:switch BaseDao from meta_data (#693)

This commit is contained in:
magic.chen 2023-10-19 14:54:16 +08:00 committed by GitHub
commit a833e8b045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 44 deletions

View File

@ -1,13 +1,11 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.base_dao import BaseDao from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
Base = declarative_base()
class ChatFeedBackEntity(Base): class ChatFeedBackEntity(Base):
__tablename__ = "chat_feed_back" __tablename__ = "chat_feed_back"
@ -33,13 +31,15 @@ class ChatFeedBackEntity(Base):
class ChatFeedBackDao(BaseDao): class ChatFeedBackDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__(database="history", orm_base=Base, create_not_exist_table=True) super().__init__(
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody): def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first. # Todo: We need to have user information first.
def_user_name = "" def_user_name = ""
session = self.Session() session = self.get_session()
chat_feed_back = ChatFeedBackEntity( chat_feed_back = ChatFeedBackEntity(
conv_uid=feed_back.conv_uid, conv_uid=feed_back.conv_uid,
conv_index=feed_back.conv_index, conv_index=feed_back.conv_index,
@ -73,7 +73,7 @@ class ChatFeedBackDao(BaseDao):
session.close() session.close()
def get_chat_feed_back(self, conv_uid: str, conv_index: int): def get_chat_feed_back(self, conv_uid: str, conv_index: int):
session = self.Session() session = self.get_session()
result = ( result = (
session.query(ChatFeedBackEntity) session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == conv_uid) .filter(ChatFeedBackEntity.conv_uid == conv_uid)

View File

@ -2,15 +2,13 @@ from datetime import datetime
from typing import List from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class DocumentChunkEntity(Base): class DocumentChunkEntity(Base):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
@ -30,11 +28,11 @@ class DocumentChunkEntity(Base):
class DocumentChunkDao(BaseDao): class DocumentChunkDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_documents_chunks(self, documents: List): def create_documents_chunks(self, documents: List):
session = self.Session() session = self.get_session()
docs = [ docs = [
DocumentChunkEntity( DocumentChunkEntity(
doc_name=document.doc_name, doc_name=document.doc_name,
@ -52,7 +50,7 @@ class DocumentChunkDao(BaseDao):
session.close() session.close()
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20): def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session() session = self.get_session()
document_chunks = session.query(DocumentChunkEntity) document_chunks = session.query(DocumentChunkEntity)
if query.id is not None: if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -82,7 +80,7 @@ class DocumentChunkDao(BaseDao):
return result return result
def get_document_chunks_count(self, query: DocumentChunkEntity): def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session() session = self.get_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id)) document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None: if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -107,13 +105,13 @@ class DocumentChunkDao(BaseDao):
return count return count
# def update_knowledge_document(self, document:KnowledgeDocumentEntity): # def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session() # session = self.get_session()
# updated_space = session.merge(document) # updated_space = session.merge(document)
# session.commit() # session.commit()
# return updated_space.id # return updated_space.id
def delete(self, document_id: int): def delete(self, document_id: int):
session = self.Session() session = self.get_session()
if document_id is None: if document_id is None:
raise Exception("document_id is None") raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id) query = DocumentChunkEntity(document_id=document_id)

View File

@ -1,15 +1,13 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class KnowledgeDocumentEntity(Base): class KnowledgeDocumentEntity(Base):
__tablename__ = "knowledge_document" __tablename__ = "knowledge_document"
@ -33,11 +31,11 @@ class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentDao(BaseDao): class KnowledgeDocumentDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_knowledge_document(self, document: KnowledgeDocumentEntity): def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
knowledge_document = KnowledgeDocumentEntity( knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name, doc_name=document.doc_name,
doc_type=document.doc_type, doc_type=document.doc_type,
@ -58,7 +56,7 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20): def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -92,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
return result return result
def get_documents(self, query): def get_documents(self, query):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -123,7 +121,7 @@ class KnowledgeDocumentDao(BaseDao):
return result return result
def get_knowledge_documents_count(self, query): def get_knowledge_documents_count(self, query):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -150,14 +148,14 @@ class KnowledgeDocumentDao(BaseDao):
return count return count
def update_knowledge_document(self, document: KnowledgeDocumentEntity): def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
updated_space = session.merge(document) updated_space = session.merge(document)
session.commit() session.commit()
return updated_space.id return updated_space.id
# #
def delete(self, query: KnowledgeDocumentEntity): def delete(self, query: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(

View File

@ -1,14 +1,13 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class KnowledgeSpaceEntity(Base): class KnowledgeSpaceEntity(Base):
@ -29,11 +28,11 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao(BaseDao): class KnowledgeSpaceDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_knowledge_space(self, space: KnowledgeSpaceRequest): def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session() session = self.get_session()
knowledge_space = KnowledgeSpaceEntity( knowledge_space = KnowledgeSpaceEntity(
name=space.name, name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE, vector_type=CFG.VECTOR_STORE_TYPE,
@ -47,7 +46,7 @@ class KnowledgeSpaceDao(BaseDao):
session.close() session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity): def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity) knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None: if query.id is not None:
knowledge_spaces = knowledge_spaces.filter( knowledge_spaces = knowledge_spaces.filter(
@ -86,14 +85,14 @@ class KnowledgeSpaceDao(BaseDao):
return result return result
def update_knowledge_space(self, space: KnowledgeSpaceEntity): def update_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
session.merge(space) session.merge(space)
session.commit() session.commit()
session.close() session.close()
return True return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity): def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
if space: if space:
session.delete(space) session.delete(space)
session.commit() session.commit()

View File

@ -1,15 +1,14 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config() CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base): class PromptManageEntity(Base):
@ -31,11 +30,11 @@ class PromptManageEntity(Base):
class PromptManageDao(BaseDao): class PromptManageDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_prompt(self, prompt: PromptManageRequest): def create_prompt(self, prompt: PromptManageRequest):
session = self.Session() session = self.get_session()
prompt_manage = PromptManageEntity( prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene, chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene, sub_chat_scene=prompt.sub_chat_scene,
@ -51,7 +50,7 @@ class PromptManageDao(BaseDao):
session.close() session.close()
def get_prompts(self, query: PromptManageEntity): def get_prompts(self, query: PromptManageEntity):
session = self.Session() session = self.get_session()
prompts = session.query(PromptManageEntity) prompts = session.query(PromptManageEntity)
if query.chat_scene is not None: if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
@ -78,13 +77,13 @@ class PromptManageDao(BaseDao):
return result return result
def update_prompt(self, prompt: PromptManageEntity): def update_prompt(self, prompt: PromptManageEntity):
session = self.Session() session = self.get_session()
session.merge(prompt) session.merge(prompt)
session.commit() session.commit()
session.close() session.close()
def delete_prompt(self, prompt: PromptManageEntity): def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session() session = self.get_session()
if prompt: if prompt:
session.delete(prompt) session.delete(prompt)
session.commit() session.commit()