mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -3,19 +3,13 @@ from typing import List
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DocumentChunkEntity(Base):
|
||||
class DocumentChunkEntity(Model):
|
||||
__tablename__ = "document_chunk"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):
|
||||
|
||||
|
||||
class DocumentChunkDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_documents_chunks(self, documents: List):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
docs = [
|
||||
DocumentChunkEntity(
|
||||
doc_name=document.doc_name,
|
||||
@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
|
||||
def get_document_chunks(
|
||||
self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
|
||||
):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
document_chunks = session.query(DocumentChunkEntity)
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_document_chunks_count(self, query: DocumentChunkEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
document_chunks = session.query(func.count(DocumentChunkEntity.id))
|
||||
if query.id is not None:
|
||||
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
|
||||
@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
|
||||
return count
|
||||
|
||||
def delete(self, document_id: int):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if document_id is None:
|
||||
raise Exception("document_id is None")
|
||||
query = DocumentChunkEntity(document_id=document_id)
|
||||
|
@@ -2,19 +2,13 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeDocumentEntity(Base):
|
||||
class KnowledgeDocumentEntity(Model):
|
||||
__tablename__ = "knowledge_document"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):
|
||||
|
||||
|
||||
class KnowledgeDocumentDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_document = KnowledgeDocumentEntity(
|
||||
doc_name=document.doc_name,
|
||||
doc_type=document.doc_type,
|
||||
@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return doc_id
|
||||
|
||||
def get_knowledge_documents(self, query, page=1, page_size=20):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_documents(self, query):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
print(f"current session:{session}")
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return result
|
||||
|
||||
def get_knowledge_documents_count_bulk(self, space_names):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
"""
|
||||
Perform a batch query to count the number of documents for each knowledge space.
|
||||
|
||||
@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return docs_count
|
||||
|
||||
def get_knowledge_documents_count(self, query):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
|
||||
return count
|
||||
|
||||
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
updated_space = session.merge(document)
|
||||
session.commit()
|
||||
return updated_space.id
|
||||
|
||||
#
|
||||
def delete(self, query: KnowledgeDocumentEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_documents = session.query(KnowledgeDocumentEntity)
|
||||
if query.id is not None:
|
||||
knowledge_documents = knowledge_documents.filter(
|
||||
|
@@ -2,20 +2,14 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, Text, String, DateTime
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
class KnowledgeSpaceEntity(Model):
|
||||
__tablename__ = "knowledge_space"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):
|
||||
|
||||
|
||||
class KnowledgeSpaceDao(BaseDao):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_space = KnowledgeSpaceEntity(
|
||||
name=space.name,
|
||||
vector_type=CFG.VECTOR_STORE_TYPE,
|
||||
@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
session.close()
|
||||
|
||||
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
knowledge_spaces = session.query(KnowledgeSpaceEntity)
|
||||
if query.id is not None:
|
||||
knowledge_spaces = knowledge_spaces.filter(
|
||||
@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
|
||||
return result
|
||||
|
||||
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
session.merge(space)
|
||||
session.commit()
|
||||
session.close()
|
||||
return True
|
||||
|
||||
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if space:
|
||||
session.delete(space)
|
||||
session.commit()
|
||||
|
Reference in New Issue
Block a user