feat:switch BaseDao from meta_data

This commit is contained in:
aries_ckt 2023-10-19 14:33:25 +08:00
parent 9e5a7bea1f
commit 9b662c09f1
4 changed files with 31 additions and 37 deletions

View File

@ -2,15 +2,13 @@ from datetime import datetime
from typing import List from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class DocumentChunkEntity(Base): class DocumentChunkEntity(Base):
__tablename__ = "document_chunk" __tablename__ = "document_chunk"
@ -30,11 +28,11 @@ class DocumentChunkEntity(Base):
class DocumentChunkDao(BaseDao): class DocumentChunkDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_documents_chunks(self, documents: List): def create_documents_chunks(self, documents: List):
session = self.Session() session = self.get_session()
docs = [ docs = [
DocumentChunkEntity( DocumentChunkEntity(
doc_name=document.doc_name, doc_name=document.doc_name,
@ -52,7 +50,7 @@ class DocumentChunkDao(BaseDao):
session.close() session.close()
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20): def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session() session = self.get_session()
document_chunks = session.query(DocumentChunkEntity) document_chunks = session.query(DocumentChunkEntity)
if query.id is not None: if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -82,7 +80,7 @@ class DocumentChunkDao(BaseDao):
return result return result
def get_document_chunks_count(self, query: DocumentChunkEntity): def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.Session() session = self.get_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id)) document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None: if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@ -107,13 +105,13 @@ class DocumentChunkDao(BaseDao):
return count return count
# def update_knowledge_document(self, document:KnowledgeDocumentEntity): # def update_knowledge_document(self, document:KnowledgeDocumentEntity):
# session = self.Session() # session = self.get_session()
# updated_space = session.merge(document) # updated_space = session.merge(document)
# session.commit() # session.commit()
# return updated_space.id # return updated_space.id
def delete(self, document_id: int): def delete(self, document_id: int):
session = self.Session() session = self.get_session()
if document_id is None: if document_id is None:
raise Exception("document_id is None") raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id) query = DocumentChunkEntity(document_id=document_id)

View File

@ -1,15 +1,13 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func from sqlalchemy import Column, String, DateTime, Integer, Text, func
from sqlalchemy.orm import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class KnowledgeDocumentEntity(Base): class KnowledgeDocumentEntity(Base):
__tablename__ = "knowledge_document" __tablename__ = "knowledge_document"
@ -33,11 +31,11 @@ class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentDao(BaseDao): class KnowledgeDocumentDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_knowledge_document(self, document: KnowledgeDocumentEntity): def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
knowledge_document = KnowledgeDocumentEntity( knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name, doc_name=document.doc_name,
doc_type=document.doc_type, doc_type=document.doc_type,
@ -58,7 +56,7 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20): def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -92,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
return result return result
def get_documents(self, query): def get_documents(self, query):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -123,7 +121,7 @@ class KnowledgeDocumentDao(BaseDao):
return result return result
def get_knowledge_documents_count(self, query): def get_knowledge_documents_count(self, query):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id)) knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(
@ -150,14 +148,14 @@ class KnowledgeDocumentDao(BaseDao):
return count return count
def update_knowledge_document(self, document: KnowledgeDocumentEntity): def update_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
updated_space = session.merge(document) updated_space = session.merge(document)
session.commit() session.commit()
return updated_space.id return updated_space.id
# #
def delete(self, query: KnowledgeDocumentEntity): def delete(self, query: KnowledgeDocumentEntity):
session = self.Session() session = self.get_session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter( knowledge_documents = knowledge_documents.filter(

View File

@ -1,14 +1,13 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.connections.rdbms.base_dao import BaseDao
CFG = Config() CFG = Config()
Base = declarative_base()
class KnowledgeSpaceEntity(Base): class KnowledgeSpaceEntity(Base):
@ -29,11 +28,11 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao(BaseDao): class KnowledgeSpaceDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="knowledge_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_knowledge_space(self, space: KnowledgeSpaceRequest): def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session() session = self.get_session()
knowledge_space = KnowledgeSpaceEntity( knowledge_space = KnowledgeSpaceEntity(
name=space.name, name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE, vector_type=CFG.VECTOR_STORE_TYPE,
@ -47,7 +46,7 @@ class KnowledgeSpaceDao(BaseDao):
session.close() session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity): def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity) knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None: if query.id is not None:
knowledge_spaces = knowledge_spaces.filter( knowledge_spaces = knowledge_spaces.filter(
@ -86,14 +85,14 @@ class KnowledgeSpaceDao(BaseDao):
return result return result
def update_knowledge_space(self, space: KnowledgeSpaceEntity): def update_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
session.merge(space) session.merge(space)
session.commit() session.commit()
session.close() session.close()
return True return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity): def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.Session() session = self.get_session()
if space: if space:
session.delete(space) session.delete(space)
session.commit() session.commit()

View File

@ -1,15 +1,14 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime from sqlalchemy import Column, Integer, Text, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.connections.rdbms.base_dao import BaseDao
from pilot.server.prompt.request.request import PromptManageRequest from pilot.server.prompt.request.request import PromptManageRequest
CFG = Config() CFG = Config()
Base = declarative_base()
class PromptManageEntity(Base): class PromptManageEntity(Base):
@ -31,11 +30,11 @@ class PromptManageEntity(Base):
class PromptManageDao(BaseDao): class PromptManageDao(BaseDao):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="prompt_management", orm_base=Base, create_not_exist_table=True database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def create_prompt(self, prompt: PromptManageRequest): def create_prompt(self, prompt: PromptManageRequest):
session = self.Session() session = self.get_session()
prompt_manage = PromptManageEntity( prompt_manage = PromptManageEntity(
chat_scene=prompt.chat_scene, chat_scene=prompt.chat_scene,
sub_chat_scene=prompt.sub_chat_scene, sub_chat_scene=prompt.sub_chat_scene,
@ -51,7 +50,7 @@ class PromptManageDao(BaseDao):
session.close() session.close()
def get_prompts(self, query: PromptManageEntity): def get_prompts(self, query: PromptManageEntity):
session = self.Session() session = self.get_session()
prompts = session.query(PromptManageEntity) prompts = session.query(PromptManageEntity)
if query.chat_scene is not None: if query.chat_scene is not None:
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene) prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
@ -78,13 +77,13 @@ class PromptManageDao(BaseDao):
return result return result
def update_prompt(self, prompt: PromptManageEntity): def update_prompt(self, prompt: PromptManageEntity):
session = self.Session() session = self.get_session()
session.merge(prompt) session.merge(prompt)
session.commit() session.commit()
session.close() session.close()
def delete_prompt(self, prompt: PromptManageEntity): def delete_prompt(self, prompt: PromptManageEntity):
session = self.Session() session = self.get_session()
if prompt: if prompt:
session.delete(prompt) session.delete(prompt)
session.commit() session.commit()