feat:switch BaseDao from meta_data (#693)

This commit is contained in:
magic.chen 2023-10-19 14:54:16 +08:00 committed by GitHub
commit a833e8b045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 44 deletions

View File

@ -1,13 +1,11 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
Base = declarative_base()
class ChatFeedBackEntity(Base):
__tablename__ = "chat_feed_back"
@ -33,13 +31,15 @@ class ChatFeedBackEntity(Base):
class ChatFeedBackDao(BaseDao):
def __init__(self):
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
super().__init__(
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
# Todo: We need to have user information first.
def_user_name = ""
session = self.Session()
session = self.get_session()
chat_feed_back = ChatFeedBackEntity(
conv_uid=feed_back.conv_uid,
conv_index=feed_back.conv_index,
@ -73,7 +73,7 @@ class ChatFeedBackDao(BaseDao):
session.close()
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
session = self.Session()
session = self.get_session()
result = (
session.query(ChatFeedBackEntity)
.filter(ChatFeedBackEntity.conv_uid == conv_uid)

View File

@ -2,15 +2,13 @@ from datetime import datetime
from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config()
Base = declarative_base()
class DocumentChunkEntity(Base):
__tablename__ = "document_chunk"
@ -30,11 +28,11 @@ class DocumentChunkEntity(Base):
class DocumentChunkDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_documents_chunks(self, documents: List):
session = self.Session()
session = self.get_session()
docs = [
DocumentChunkEntity(
doc_name=document.doc_name,
@ -52,7 +50,7 @@ class DocumentChunkDao(BaseDao):
session.close()
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session()
session = self.get_session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -82,7 +80,7 @@ class DocumentChunkDao(BaseDao):
return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session()
session = self.get_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -107,13 +105,13 @@ class DocumentChunkDao(BaseDao):
return count
# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session()
# session = self.get_session()
# updated_space = session.merge(document)
# session.commit()
# return updated_space.id
def delete(self, document_id: int):
session = self.Session()
session = self.get_session()
if document_id is None:
raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id)

View File

@ -1,15 +1,13 @@
from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config()
Base = declarative_base()
class KnowledgeDocumentEntity(Base):
__tablename__ = "knowledge_document"
@ -33,11 +31,11 @@ class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
doc_type=document.doc_type,
@ -58,7 +56,7 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
@ -92,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_documents(self, query):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
@ -123,7 +121,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_knowledge_documents_count(self, query):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
@ -150,14 +148,14 @@ class KnowledgeDocumentDao(BaseDao):
return count
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
updated_space = session.merge(document)
session.commit()
return updated_space.id
#
def delete(self, query: KnowledgeDocumentEntity):
session = self.Session()
session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(

View File

@ -1,14 +1,13 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config()
Base = declarative_base()
class KnowledgeSpaceEntity(Base):
@ -29,11 +28,11 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao(BaseDao):
def __init__(self):
super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session()
session = self.get_session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE,
@ -47,7 +46,7 @@ class KnowledgeSpaceDao(BaseDao):
session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(
@ -86,14 +85,14 @@ class KnowledgeSpaceDao(BaseDao):
return result
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
session.merge(space)
session.commit()
session.close()
return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session()
session = self.get_session()
if space:
session.delete(space)
session.commit()

View File

@ -1,15 +1,14 @@
from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base):
@ -31,11 +30,11 @@ class PromptManageEntity(Base):
class PromptManageDao(BaseDao):
def __init__(self):
super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def create_prompt(self, prompt: PromptManageRequest):
session = self.Session()
session = self.get_session()
prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene,
@ -51,7 +50,7 @@ class PromptManageDao(BaseDao):
session.close()
def get_prompts(self, query: PromptManageEntity):
session = self.Session()
session = self.get_session()
prompts = session.query(PromptManageEntity)
if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
@ -78,13 +77,13 @@ class PromptManageDao(BaseDao):
return result
def update_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session = self.get_session()
session.merge(prompt)
session.commit()
session.close()
def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session()
session = self.get_session()
if prompt:
session.delete(prompt)
session.commit()