mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 02:20:08 +00:00
feat:switch BaseDao from meta_data (#693)
This commit is contained in:
commit
a833e8b045
@ -1,13 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
from pilot.openapi.api_v1.feedback.feed_back_model import FeedBackBody
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class ChatFeedBackEntity(Base):
|
class ChatFeedBackEntity(Base):
|
||||||
__tablename__ = "chat_feed_back"
|
__tablename__ = "chat_feed_back"
|
||||||
@ -33,13 +31,15 @@ class ChatFeedBackEntity(Base):
|
|||||||
|
|
||||||
class ChatFeedBackDao(BaseDao):
|
class ChatFeedBackDao(BaseDao):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(database="history", orm_base=Base, create_not_exist_table=True)
|
super().__init__(
|
||||||
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
|
)
|
||||||
|
|
||||||
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
def create_or_update_chat_feed_back(self, feed_back: FeedBackBody):
|
||||||
# Todo: We need to have user information first.
|
# Todo: We need to have user information first.
|
||||||
def_user_name = ""
|
def_user_name = ""
|
||||||
|
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
chat_feed_back = ChatFeedBackEntity(
|
chat_feed_back = ChatFeedBackEntity(
|
||||||
conv_uid=feed_back.conv_uid,
|
conv_uid=feed_back.conv_uid,
|
||||||
conv_index=feed_back.conv_index,
|
conv_index=feed_back.conv_index,
|
||||||
@ -73,7 +73,7 @@ class ChatFeedBackDao(BaseDao):
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
def get_chat_feed_back(self, conv_uid: str, conv_index: int):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
result = (
|
result = (
|
||||||
session.query(ChatFeedBackEntity)
|
session.query(ChatFeedBackEntity)
|
||||||
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
.filter(ChatFeedBackEntity.conv_uid == conv_uid)
|
||||||
|
@ -2,15 +2,13 @@ from datetime import datetime
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunkEntity(Base):
|
class DocumentChunkEntity(Base):
|
||||||
__tablename__ = "document_chunk"
|
__tablename__ = "document_chunk"
|
||||||
@ -30,11 +28,11 @@ class DocumentChunkEntity(Base):
|
|||||||
class DocumentChunkDao(BaseDao):
|
class DocumentChunkDao(BaseDao):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_documents_chunks(self, documents: List):
|
def create_documents_chunks(self, documents: List):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
docs = [
|
docs = [
|
||||||
DocumentChunkEntity(
|
DocumentChunkEntity(
|
||||||
doc_name=document.doc_name,
|
doc_name=document.doc_name,
|
||||||
@ -52,7 +50,7 @@ class DocumentChunkDao(BaseDao):
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
|
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
document_chunks = session.query(DocumentChunkEntity)
|
document_chunks = session.query(DocumentChunkEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||||
@ -82,7 +80,7 @@ class DocumentChunkDao(BaseDao):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||||
@ -107,13 +105,13 @@ class DocumentChunkDao(BaseDao):
|
|||||||
return count
|
return count
|
||||||
|
|
||||||
# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
|
# def update_knowledge_document(self, document:KnowledgeDocumentEntity):
|
||||||
# session = self.Session()
|
# session = self.get_session()
|
||||||
# updated_space = session.merge(document)
|
# updated_space = session.merge(document)
|
||||||
# session.commit()
|
# session.commit()
|
||||||
# return updated_space.id
|
# return updated_space.id
|
||||||
|
|
||||||
def delete(self, document_id: int):
|
def delete(self, document_id: int):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
if document_id is None:
|
if document_id is None:
|
||||||
raise Exception("document_id is None")
|
raise Exception("document_id is None")
|
||||||
query = DocumentChunkEntity(document_id=document_id)
|
query = DocumentChunkEntity(document_id=document_id)
|
||||||
|
@ -1,15 +1,13 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeDocumentEntity(Base):
|
class KnowledgeDocumentEntity(Base):
|
||||||
__tablename__ = "knowledge_document"
|
__tablename__ = "knowledge_document"
|
||||||
@ -33,11 +31,11 @@ class KnowledgeDocumentEntity(Base):
|
|||||||
class KnowledgeDocumentDao(BaseDao):
|
class KnowledgeDocumentDao(BaseDao):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_document = KnowledgeDocumentEntity(
|
knowledge_document = KnowledgeDocumentEntity(
|
||||||
doc_name=document.doc_name,
|
doc_name=document.doc_name,
|
||||||
doc_type=document.doc_type,
|
doc_type=document.doc_type,
|
||||||
@ -58,7 +56,7 @@ class KnowledgeDocumentDao(BaseDao):
|
|||||||
return doc_id
|
return doc_id
|
||||||
|
|
||||||
def get_knowledge_documents(self, query, page=1, page_size=20):
|
def get_knowledge_documents(self, query, page=1, page_size=20):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(
|
knowledge_documents = knowledge_documents.filter(
|
||||||
@ -92,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_documents(self, query):
|
def get_documents(self, query):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(
|
knowledge_documents = knowledge_documents.filter(
|
||||||
@ -123,7 +121,7 @@ class KnowledgeDocumentDao(BaseDao):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_knowledge_documents_count(self, query):
|
def get_knowledge_documents_count(self, query):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(
|
knowledge_documents = knowledge_documents.filter(
|
||||||
@ -150,14 +148,14 @@ class KnowledgeDocumentDao(BaseDao):
|
|||||||
return count
|
return count
|
||||||
|
|
||||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
updated_space = session.merge(document)
|
updated_space = session.merge(document)
|
||||||
session.commit()
|
session.commit()
|
||||||
return updated_space.id
|
return updated_space.id
|
||||||
|
|
||||||
#
|
#
|
||||||
def delete(self, query: KnowledgeDocumentEntity):
|
def delete(self, query: KnowledgeDocumentEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_documents = knowledge_documents.filter(
|
knowledge_documents = knowledge_documents.filter(
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeSpaceEntity(Base):
|
class KnowledgeSpaceEntity(Base):
|
||||||
@ -29,11 +28,11 @@ class KnowledgeSpaceEntity(Base):
|
|||||||
class KnowledgeSpaceDao(BaseDao):
|
class KnowledgeSpaceDao(BaseDao):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
database="knowledge_management", orm_base=Base, create_not_exist_table=True
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_space = KnowledgeSpaceEntity(
|
knowledge_space = KnowledgeSpaceEntity(
|
||||||
name=space.name,
|
name=space.name,
|
||||||
vector_type=CFG.VECTOR_STORE_TYPE,
|
vector_type=CFG.VECTOR_STORE_TYPE,
|
||||||
@ -47,7 +46,7 @@ class KnowledgeSpaceDao(BaseDao):
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
knowledge_spaces = knowledge_spaces.filter(
|
knowledge_spaces = knowledge_spaces.filter(
|
||||||
@ -86,14 +85,14 @@ class KnowledgeSpaceDao(BaseDao):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
session.merge(space)
|
session.merge(space)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
if space:
|
if space:
|
||||||
session.delete(space)
|
session.delete(space)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
@ -1,15 +1,14 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.connections.rdbms.base_dao import BaseDao
|
|
||||||
|
|
||||||
from pilot.server.prompt.request.request import PromptManageRequest
|
from pilot.server.prompt.request.request import PromptManageRequest
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
class PromptManageEntity(Base):
|
class PromptManageEntity(Base):
|
||||||
@ -31,11 +30,11 @@ class PromptManageEntity(Base):
|
|||||||
class PromptManageDao(BaseDao):
|
class PromptManageDao(BaseDao):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
database="prompt_management", orm_base=Base, create_not_exist_table=True
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_prompt(self, prompt: PromptManageRequest):
|
def create_prompt(self, prompt: PromptManageRequest):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
prompt_manage = PromptManageEntity(
|
prompt_manage = PromptManageEntity(
|
||||||
chat_scene=prompt.chat_scene,
|
chat_scene=prompt.chat_scene,
|
||||||
sub_chat_scene=prompt.sub_chat_scene,
|
sub_chat_scene=prompt.sub_chat_scene,
|
||||||
@ -51,7 +50,7 @@ class PromptManageDao(BaseDao):
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def get_prompts(self, query: PromptManageEntity):
|
def get_prompts(self, query: PromptManageEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
prompts = session.query(PromptManageEntity)
|
prompts = session.query(PromptManageEntity)
|
||||||
if query.chat_scene is not None:
|
if query.chat_scene is not None:
|
||||||
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
prompts = prompts.filter(PromptManageEntity.chat_scene == query.chat_scene)
|
||||||
@ -78,13 +77,13 @@ class PromptManageDao(BaseDao):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def update_prompt(self, prompt: PromptManageEntity):
|
def update_prompt(self, prompt: PromptManageEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
session.merge(prompt)
|
session.merge(prompt)
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def delete_prompt(self, prompt: PromptManageEntity):
|
def delete_prompt(self, prompt: PromptManageEntity):
|
||||||
session = self.Session()
|
session = self.get_session()
|
||||||
if prompt:
|
if prompt:
|
||||||
session.delete(prompt)
|
session.delete(prompt)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
Loading…
Reference in New Issue
Block a user