refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -3,19 +3,13 @@ from typing import List
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
CFG = Config()
class DocumentChunkEntity(Base):
class DocumentChunkEntity(Model):
__tablename__ = "document_chunk"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class DocumentChunkEntity(Base):
class DocumentChunkDao(BaseDao):
# Constructor: passes META_DATA_DATABASE, Base, engine and session (the names
# imported from dbgpt.storage.metadata.meta_data in this listing) to
# BaseDao.__init__ as the database, orm_base, db_engine and session keywords.
# NOTE(review): this listing looks like a diff with its +/- markers and
# indentation stripped; the enclosing hunk header (-35,16 +29,8) suggests this
# whole constructor is on the removed side of the change — confirm against the
# repository before relying on it.
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_documents_chunks(self, documents: List):
session = self.get_session()
session = self.get_raw_session()
docs = [
DocumentChunkEntity(
doc_name=document.doc_name,
@@ -64,7 +50,7 @@ class DocumentChunkDao(BaseDao):
def get_document_chunks(
self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None
):
session = self.get_session()
session = self.get_raw_session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -102,7 +88,7 @@ class DocumentChunkDao(BaseDao):
return result
def get_document_chunks_count(self, query: DocumentChunkEntity):
session = self.get_session()
session = self.get_raw_session()
document_chunks = session.query(func.count(DocumentChunkEntity.id))
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
@@ -127,7 +113,7 @@ class DocumentChunkDao(BaseDao):
return count
def delete(self, document_id: int):
session = self.get_session()
session = self.get_raw_session()
if document_id is None:
raise Exception("document_id is None")
query = DocumentChunkEntity(document_id=document_id)

View File

@@ -2,19 +2,13 @@ from datetime import datetime
from sqlalchemy import Column, String, DateTime, Integer, Text, func
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
CFG = Config()
class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentEntity(Model):
__tablename__ = "knowledge_document"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class KnowledgeDocumentEntity(Base):
class KnowledgeDocumentDao(BaseDao):
# Constructor: passes META_DATA_DATABASE, Base, engine and session (the names
# imported from dbgpt.storage.metadata.meta_data in this listing) to
# BaseDao.__init__ as the database, orm_base, db_engine and session keywords.
# NOTE(review): this listing looks like a diff with its +/- markers and
# indentation stripped; the enclosing hunk header (-39,16 +33,8) suggests this
# whole constructor is on the removed side of the change — confirm against the
# repository before relying on it.
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
doc_type=document.doc_type,
@@ -69,7 +55,7 @@ class KnowledgeDocumentDao(BaseDao):
return doc_id
def get_knowledge_documents(self, query, page=1, page_size=20):
session = self.get_session()
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
@@ -104,7 +90,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_documents(self, query):
session = self.get_session()
session = self.get_raw_session()
print(f"current session:{session}")
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
@@ -136,7 +122,7 @@ class KnowledgeDocumentDao(BaseDao):
return result
def get_knowledge_documents_count_bulk(self, space_names):
session = self.get_session()
session = self.get_raw_session()
"""
Perform a batch query to count the number of documents for each knowledge space.
@@ -161,7 +147,7 @@ class KnowledgeDocumentDao(BaseDao):
return docs_count
def get_knowledge_documents_count(self, query):
session = self.get_session()
session = self.get_raw_session()
knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
if query.id is not None:
knowledge_documents = knowledge_documents.filter(
@@ -188,14 +174,14 @@ class KnowledgeDocumentDao(BaseDao):
return count
# Merge the given document entity into a session, commit, and return the id
# of the merged instance.
# NOTE(review): no session.close() call appears in the lines shown for this
# method.
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
# NOTE(review): this listing looks like a diff with its +/- markers stripped;
# the next two lines are presumably the removed and added sides of a single
# assignment (get_session -> get_raw_session), not two live statements —
# confirm against the repository.
session = self.get_session()
session = self.get_raw_session()
# Keep the value returned by merge() so its id can be returned after commit.
updated_space = session.merge(document)
session.commit()
return updated_space.id
#
def delete(self, query: KnowledgeDocumentEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(

View File

@@ -2,20 +2,14 @@ from datetime import datetime
from sqlalchemy import Column, Integer, Text, String, DateTime
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
CFG = Config()
class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceEntity(Model):
__tablename__ = "knowledge_space"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -35,16 +29,8 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao(BaseDao):
# Constructor: passes META_DATA_DATABASE, Base, engine and session (the names
# imported from dbgpt.storage.metadata.meta_data in this listing) to
# BaseDao.__init__ as the database, orm_base, db_engine and session keywords.
# NOTE(review): this listing looks like a diff with its +/- markers and
# indentation stripped; the enclosing hunk header (-35,16 +29,8) suggests this
# whole constructor is on the removed side of the change — confirm against the
# repository before relying on it.
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.get_session()
session = self.get_raw_session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
vector_type=CFG.VECTOR_STORE_TYPE,
@@ -58,7 +44,7 @@ class KnowledgeSpaceDao(BaseDao):
session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
session = self.get_session()
session = self.get_raw_session()
knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(
@@ -97,14 +83,14 @@ class KnowledgeSpaceDao(BaseDao):
return result
# Merge the given space entity into a session, commit, close the session, and
# return True.
def update_knowledge_space(self, space: KnowledgeSpaceEntity):
# NOTE(review): this listing looks like a diff with its +/- markers stripped;
# the next two lines are presumably the removed and added sides of a single
# assignment (get_session -> get_raw_session), not two live statements —
# confirm against the repository.
session = self.get_session()
session = self.get_raw_session()
session.merge(space)
session.commit()
# The session is closed only after a successful commit; nothing in the lines
# shown closes it if merge() or commit() raises.
session.close()
return True
def delete_knowledge_space(self, space: KnowledgeSpaceEntity):
session = self.get_session()
session = self.get_raw_session()
if space:
session.delete(space)
session.commit()