mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
feat:switch BaseDao from meta_data (#693)
This commit is contained in:
commit
a833e8b045
@ -1,13 +1,11 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ChatFeedBackEntity(Base):
|
||||
__tablename__ = "chat_feed_back"
|
||||
@ -33,13 +31,15 @@ class ChatFeedBackEntity(Base):
|
||||
|
||||
class ChatFeedBackDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
|
||||
super().__init__(
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
||||
# Todo: We need to have user information first.
|
||||
def_user_name = ""
|
||||
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
chat_feed_back = ChatFeedBackEntity(
|
||||
conv_uid=feed_back.conv_uid,
|
||||
conv_index=feed_back.conv_index,
|
||||
@ -73,7 +73,7 @@ class ChatFeedBackDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
result = (
|
||||
session.query(ChatFeedBackEntity)
|
||||
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
||||
|
@ -2,15 +2,13 @@ from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
||||
CFG = Config()
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class DocumentChunkEntity(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
@ -30,11 +28,11 @@ class DocumentChunkEntity(Base):
|
||||
class DocumentChunkDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def create_documents_chunks(self, documents: List):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
docs = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=document.doc_name,
|
||||
@ -52,7 +50,7 @@ class DocumentChunkDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
document_chunks = session.query(DocumentChunkEntity)
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@ -82,7 +80,7 @@ class DocumentChunkDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@ -107,13 +105,13 @@ class DocumentChunkDao(BaseDao):
|
||||
return count
|
||||
|
||||
# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
|
||||
# session = self.Session()
|
||||
# session = self.get_session()
|
||||
# updated_space = session.merge(document)
|
||||
# session.commit()
|
||||
# return updated_space.id
|
||||
|
||||
def delete(self, document_id: int):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
if document_id is None:
|
||||
raise Exception("document_id is None")
|
||||
query = DocumentChunkEntity(document_id=document_id)
|
||||
|
@ -1,15 +1,13 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
from sqlalchemy.orm import declarative_base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
||||
CFG = Config()
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeDocumentEntity(Base):
|
||||
__tablename__ = "knowledge_document"
|
||||
@ -33,11 +31,11 @@ class KnowledgeDocumentEntity(Base):
|
||||
class KnowledgeDocumentDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_document = KnowledgeDocumentEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
@ -58,7 +56,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return doc_id
|
||||
|
||||
def get_knowledge_documents(self, query, page=1, page_size=20):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
@ -92,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_documents(self, query):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
@ -123,7 +121,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_knowledge_documents_count(self, query):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
@ -150,14 +148,14 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return count
|
||||
|
||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
updated_space = session.merge(document)
|
||||
session.commit()
|
||||
return updated_space.id
|
||||
|
||||
#
|
||||
def delete(self, query: KnowledgeDocumentEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
|
@ -1,14 +1,13 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
@ -29,11 +28,11 @@ class KnowledgeSpaceEntity(Base):
|
||||
class KnowledgeSpaceDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
vector_type=CFG.VECTOR_STORE_TYPE,
|
||||
@ -47,7 +46,7 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
@ -86,14 +85,14 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
return result
|
||||
|
||||
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
session.merge(space)
|
||||
session.commit()
|
||||
session.close()
|
||||
return True
|
||||
|
||||
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
if space:
|
||||
session.delete(space)
|
||||
session.commit()
|
||||
|
@ -1,15 +1,14 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
from pilot.configs.config import Config
|
||||
from pilot.connections.rdbms.base_dao import BaseDao
|
||||
|
||||
from pilot.server.prompt.request.request import PromptManageRequest
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PromptManageEntity(Base):
|
||||
@ -31,11 +30,11 @@ class PromptManageEntity(Base):
|
||||
class PromptManageDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="prompt_management", orm_base=Base, create_not_exist_table=True
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def create_prompt(self, prompt: PromptManageRequest):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
prompt_manage = PromptManageEntity(
|
||||
chat_scene=prompt.chat_scene,
|
||||
sub_chat_scene=prompt.sub_chat_scene,
|
||||
@ -51,7 +50,7 @@ class PromptManageDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_prompts(self, query: PromptManageEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
prompts = session.query(PromptManageEntity)
|
||||
if query.chat_scene is not None:
|
||||
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||
@ -78,13 +77,13 @@ class PromptManageDao(BaseDao):
|
||||
return result
|
||||
|
||||
def update_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
session.merge(prompt)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def delete_prompt(self, prompt: PromptManageEntity):
|
||||
session = self.Session()
|
||||
session = self.get_session()
|
||||
if prompt:
|
||||
session.delete(prompt)
|
||||
session.commit()
|
||||
|
Loading…
Reference in New Issue
Block a user