feat: knowledge management backend api

1. Create knowledge space
2. List knowledge spaces
3. Create knowledge document
4. List knowledge documents
5. Save document chunks
6. Sync document embeddings to the vector store
aries_ckt 2023-06-26 15:24:25 +08:00
parent 364f251a12
commit db28894443
13 changed files with 648 additions and 12 deletions

View File

@@ -48,9 +48,10 @@ class KnowledgeEmbedding:
     def knowledge_embedding_batch(self, docs):
         # docs = self.knowledge_embedding_client.read_batch()
-        self.knowledge_embedding_client.index_to_store(docs)
+        return self.knowledge_embedding_client.index_to_store(docs)

     def read(self):
+        self.knowledge_embedding_client = self.init_knowledge_embedding()
         return self.knowledge_embedding_client.read_batch()

     def init_knowledge_embedding(self):
@@ -66,7 +67,7 @@ class KnowledgeEmbedding:
             embedding = knowledge_class(
                 self.file_path,
                 vector_store_config=self.vector_store_config,
-                **knowledge_args,
+                **knowledge_args
             )
             return embedding
         raise ValueError(f"Unsupported knowledge file type '{extension}'")
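The two hunks above thread the vector store's ids back to the caller: index_to_store now returns them, and read() (re)initializes the embedding client before batching. A minimal sketch of the resulting flow, assuming a URL source and mirroring the client construction used in knowledge_service.py below (the example URL and space name are illustrative):

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding

CFG = Config()
client = KnowledgeEmbedding(
    file_path="https://example.com/page",  # assumed source; sync treats doc_name as a URL
    file_type="url",
    model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
    vector_store_config={"vector_store_name": "my_space"},
)
chunk_docs = client.read()                                 # split the source into chunks
vector_ids = client.knowledge_embedding_batch(chunk_docs)  # now returns the store's ids
print(",".join(vector_ids))                                # the service persists these on the document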

View File

@@ -59,7 +59,7 @@ class SourceEmbedding(ABC):
     @register
     def index_to_store(self, docs):
         """index to vector store"""
-        self.vector_client.load_document(docs)
+        return self.vector_client.load_document(docs)

     @register
     def similar_search(self, doc, topk):

View File

@@ -42,7 +42,7 @@ class ChatUrlKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            model_name=LLM_MODEL_CONFIG["text2vec"],
+            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
             vector_store_config=vector_store_config,
             file_type="url",
             file_path=url,

View File

@@ -6,21 +6,21 @@ T = TypeVar('T')
 class Result(Generic[T], BaseModel):
     success: bool
-    err_code: str
-    err_msg: str
-    data: List[T]
+    err_code: str = None
+    err_msg: str = None
+    data: List[T] = None

     @classmethod
     def succ(cls, data: List[T]):
-        return Result(True, None, None, data)
+        return Result(success=True, err_code=None, err_msg=None, data=data)

     @classmethod
     def faild(cls, msg):
-        return Result(True, "E000X", msg, None)
+        return Result(success=False, err_code="E000X", err_msg=msg, data=None)

     @classmethod
     def faild(cls, code, msg):
-        return Result(True, code, msg, None)
+        return Result(success=False, err_code=code, err_msg=msg, data=None)


 class ConversationVo(BaseModel):
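This hunk fixes two real bugs: Pydantic models reject positional construction, and the failure helpers previously reported success=True. Note that Python keeps only the last binding of a name, so the two-argument faild shadows the one-argument version at class definition time; callers (including the knowledge controller below) must pass both code and msg. A quick usage sketch:

from pilot.server.api_v1.api_view_model import Result

ok = Result.succ(data=[{"name": "default"}])             # success=True, payload in data
err = Result.faild(code="E000X", msg="space add error")  # success=False after this fix
assert ok.success is True and err.success is False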

View File

@@ -0,0 +1,83 @@
from datetime import datetime
from typing import List

from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from pilot.configs.config import Config

CFG = Config()
Base = declarative_base()


class DocumentChunkEntity(Base):
    __tablename__ = 'document_chunk'
    id = Column(Integer, primary_key=True)
    document_id = Column(Integer)
    doc_name = Column(String(100))
    doc_type = Column(String(100))
    content = Column(Text)
    meta_info = Column(String(500))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class DocumentChunkDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
            echo=True)
        self.Session = sessionmaker(bind=self.db_engine)

    def create_documents_chunks(self, documents: List):
        session = self.Session()
        docs = [
            DocumentChunkEntity(
                doc_name=document.doc_name,
                doc_type=document.doc_type,
                document_id=document.document_id,
                content=document.content or "",
                meta_info=document.meta_info or "",
                gmt_created=datetime.now(),
                gmt_modified=datetime.now()
            )
            for document in documents]
        session.add_all(docs)
        session.commit()
        session.close()

    def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
        session = self.Session()
        document_chunks = session.query(DocumentChunkEntity)
        if query.id is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
        if query.document_id is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
        if query.doc_type is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
        if query.doc_name is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
        if query.meta_info is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
        document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
        document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
        result = document_chunks.all()
        return result

    # def update_knowledge_document(self, document: KnowledgeDocumentEntity):
    #     session = self.Session()
    #     updated_space = session.merge(document)
    #     session.commit()
    #     return updated_space.id

    # def delete_knowledge_document(self, document_id: int):
    #     cursor = self.conn.cursor()
    #     query = "DELETE FROM knowledge_document WHERE id = %s"
    #     cursor.execute(query, (document_id,))
    #     self.conn.commit()
    #     cursor.close()
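The DAO assumes the knowledge_management MySQL database and its tables already exist; the commit ships no migration. Since the entity carries full column metadata, one possible way to bootstrap the table (an assumption, not part of this commit; the same pattern applies to the other DAO modules) is SQLAlchemy's create_all:

from sqlalchemy import create_engine

from pilot.configs.config import Config
from pilot.server.knowledge.document_chunk_dao import Base, DocumentChunkEntity

CFG = Config()
# The knowledge_management database itself must already exist;
# create_all only creates the missing tables inside it.
engine = create_engine(
    f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}"
    f"@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/knowledge_management"
)
Base.metadata.create_all(engine, tables=[DocumentChunkEntity.__table__])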

View File

@@ -0,0 +1,111 @@
import json
import os
import sys
from typing import List

from fastapi import APIRouter
from langchain.embeddings import HuggingFaceEmbeddings

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.server.api_v1.api_view_model import Result
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.server.knowledge.knowledge_service import KnowledgeService
from pilot.server.knowledge.request.knowledge_request import (
    KnowledgeQueryRequest,
    KnowledgeQueryResponse,
    KnowledgeDocumentRequest,
    DocumentSyncRequest,
    ChunkQueryRequest,
    DocumentQueryRequest,
)
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest

CFG = Config()
router = APIRouter()

embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL])

knowledge_space_service = KnowledgeService()


@router.post("/knowledge/space/add")
def space_add(request: KnowledgeSpaceRequest):
    print(f"/space/add params: {request}")
    try:
        knowledge_space_service.create_knowledge_space(request)
        return Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space add error {e}")


@router.post("/knowledge/space/list")
def space_list(request: KnowledgeSpaceRequest):
    print(f"/space/list params:")
    try:
        return Result.succ(knowledge_space_service.get_knowledge_space(request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"space list error {e}")


@router.post("/knowledge/{space_name}/document/add")
def document_add(space_name: str, request: KnowledgeDocumentRequest):
    print(f"/document/add params: {space_name}, {request}")
    try:
        knowledge_space_service.create_knowledge_document(
            space=space_name, request=request
        )
        return Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document add error {e}")


@router.post("/knowledge/{space_name}/document/list")
def document_list(space_name: str, query_request: DocumentQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(knowledge_space_service.get_knowledge_documents(
            space_name,
            query_request
        ))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document list error {e}")


@router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest):
    print(f"Received params: {space_name}, {request}")
    try:
        knowledge_space_service.sync_knowledge_document(
            space_name=space_name, doc_ids=request.doc_ids
        )
        Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document sync error {e}")


@router.post("/knowledge/{space_name}/chunk/list")
def document_list(space_name: str, query_request: ChunkQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        Result.succ(knowledge_space_service.get_document_chunks(
            query_request
        ))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document chunk list error {e}")


@router.post("/knowledge/{vector_name}/query")
def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
    print(f"Received params: {space_name}, {query_request}")
    client = KnowledgeEmbedding(
        model_name=embeddings, vector_store_config={"vector_store_name": space_name}
    )
    docs = client.similar_search(query_request.query, query_request.top_k)
    res = [
        KnowledgeQueryResponse(text=d.page_content, source=d.metadata["source"])
        for d in docs
    ]
    return {"response": res}
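A hedged end-to-end walkthrough of the routes above, assuming the server listens on localhost:8000 (the router is mounted on the webserver below; the exact port depends on deployment, and all payload values are illustrative). Three quirks as committed: document_sync and the chunk-list handler build Result.succ(...) without returning it, so they respond with a null body on success; the chunk-list handler reuses the function name document_list; and similar_query declares {vector_name} in the path but space_name in the signature, so FastAPI treats space_name as a required query parameter.

import requests

BASE = "http://localhost:8000"  # assumed host/port

# 1. create a knowledge space
requests.post(f"{BASE}/knowledge/space/add",
              json={"name": "my_space", "vector_type": "Chroma", "desc": "demo", "owner": "aries"})

# 2. list spaces (all filter fields are optional)
print(requests.post(f"{BASE}/knowledge/space/list", json={}).json())

# 3. register a document; doc_name doubles as the URL that sync will fetch
requests.post(f"{BASE}/knowledge/my_space/document/add",
              json={"doc_name": "https://example.com/page", "doc_type": "url"})

# 4. list documents in the space
print(requests.post(f"{BASE}/knowledge/my_space/document/list", json={}).json())

# 5. chunk + embed the documents with the given ids
requests.post(f"{BASE}/knowledge/my_space/document/sync", json={"doc_ids": [1]})

# 6. page through the persisted chunks
print(requests.post(f"{BASE}/knowledge/my_space/chunk/list",
                    json={"document_id": 1, "page": 1, "page_size": 20}).json())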

View File

@@ -0,0 +1,87 @@
from datetime import datetime

from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from pilot.configs.config import Config

CFG = Config()
Base = declarative_base()


class KnowledgeDocumentEntity(Base):
    __tablename__ = 'knowledge_document'
    id = Column(Integer, primary_key=True)
    doc_name = Column(String(100))
    doc_type = Column(String(100))
    space = Column(String(100))
    chunk_size = Column(Integer)
    status = Column(String(100))
    last_sync = Column(String(100))
    content = Column(Text)
    vector_ids = Column(Text)
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class KnowledgeDocumentDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
            echo=True)
        self.Session = sessionmaker(bind=self.db_engine)

    def create_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        knowledge_document = KnowledgeDocumentEntity(
            doc_name=document.doc_name,
            doc_type=document.doc_type,
            space=document.space,
            chunk_size=0.0,
            status=document.status,
            last_sync=document.last_sync,
            content=document.content or "",
            vector_ids=document.vector_ids,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now()
        )
        session.add(knowledge_document)
        session.commit()
        session.close()

    def get_knowledge_documents(self, query, page=1, page_size=20):
        session = self.Session()
        knowledge_documents = session.query(KnowledgeDocumentEntity)
        if query.id is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
        if query.doc_name is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
        if query.doc_type is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
        if query.space is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
        if query.status is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
        knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
        knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
        result = knowledge_documents.all()
        return result

    def update_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        updated_space = session.merge(document)
        session.commit()
        return updated_space.id

    def delete_knowledge_document(self, document_id: int):
        cursor = self.conn.cursor()
        query = "DELETE FROM knowledge_document WHERE id = %s"
        cursor.execute(query, (document_id,))
        self.conn.commit()
        cursor.close()
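Two caveats in this DAO as committed: the query and update paths never close their sessions, and delete_knowledge_document references self.conn, which __init__ never creates, so calling it would raise AttributeError. A hypothetical session-based delete, consistent with the ORM used in the rest of the class:

# Hypothetical rewrite, not part of this commit: reuse the SQLAlchemy
# session instead of the nonexistent raw self.conn connection.
def delete_knowledge_document(self, document_id: int):
    session = self.Session()
    session.query(KnowledgeDocumentEntity).filter(
        KnowledgeDocumentEntity.id == document_id
    ).delete()
    session.commit()
    session.close()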

View File

@@ -0,0 +1,173 @@
import threading
from datetime import datetime
from enum import Enum

from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.logs import logger
from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.server.knowledge.knowledge_document_dao import (
    KnowledgeDocumentDao,
    KnowledgeDocumentEntity,
)
from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.server.knowledge.request.knowledge_request import (
    KnowledgeSpaceRequest,
    KnowledgeDocumentRequest,
    DocumentQueryRequest,
    ChunkQueryRequest,
)

knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()

CFG = Config()


class SyncStatus(Enum):
    TODO = "TODO"
    FAILED = "FAILED"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"


# @singleton
class KnowledgeService:
    def __init__(self):
        pass

    """create knowledge space"""
    def create_knowledge_space(self, request: KnowledgeSpaceRequest):
        query = KnowledgeSpaceEntity(
            name=request.name,
        )
        spaces = knowledge_space_dao.get_knowledge_space(query)
        if len(spaces) > 0:
            raise Exception(f"space name:{request.name} is already in use")
        knowledge_space_dao.create_knowledge_space(request)
        return True

    """create knowledge document"""
    def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            space=space
        )
        documents = knowledge_document_dao.get_knowledge_documents(query)
        if len(documents) > 0:
            raise Exception(f"document name:{request.doc_name} is already in use")
        document = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            doc_type=request.doc_type,
            space=space,
            chunk_size=0,
            status=SyncStatus.TODO.name,
            last_sync=datetime.now(),
            content="",
        )
        knowledge_document_dao.create_knowledge_document(document)
        return True

    """get knowledge space"""
    def get_knowledge_space(self, request: KnowledgeSpaceRequest):
        query = KnowledgeSpaceEntity(
            name=request.name,
            vector_type=request.vector_type,
            owner=request.owner
        )
        return knowledge_space_dao.get_knowledge_space(query)

    """get knowledge documents"""
    def get_knowledge_documents(self, space, request: DocumentQueryRequest):
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            doc_type=request.doc_type,
            space=space,
            status=request.status,
        )
        return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size)

    """sync knowledge document chunks into the vector store"""
    def sync_knowledge_document(self, space_name, doc_ids):
        for doc_id in doc_ids:
            query = KnowledgeDocumentEntity(
                id=doc_id,
                space=space_name,
            )
            doc = knowledge_document_dao.get_knowledge_documents(query)[0]
            client = KnowledgeEmbedding(file_path=doc.doc_name,
                                        file_type="url",
                                        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                                        vector_store_config={
                                            "vector_store_name": space_name,
                                        })
            chunk_docs = client.read()
            # update document status
            doc.status = SyncStatus.RUNNING.name
            doc.chunk_size = len(chunk_docs)
            doc.gmt_modified = datetime.now()
            knowledge_document_dao.update_knowledge_document(doc)
            # async doc embeddings
            thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
            thread.start()
            # save chunk details
            chunk_entities = [
                DocumentChunkEntity(
                    doc_name=doc.doc_name,
                    doc_type=doc.doc_type,
                    document_id=doc.id,
                    content=chunk_doc.page_content,
                    meta_info=str(chunk_doc.metadata),
                    gmt_created=datetime.now(),
                    gmt_modified=datetime.now()
                )
                for chunk_doc in chunk_docs]
            document_chunk_dao.create_documents_chunks(chunk_entities)
            # update document status
            # doc.status = SyncStatus.RUNNING.name
            # doc.chunk_size = len(chunk_docs)
            # doc.gmt_modified = datetime.now()
            # knowledge_document_dao.update_knowledge_document(doc)
        return True

    """update knowledge space"""
    def update_knowledge_space(
            self, space_id: int, space_request: KnowledgeSpaceRequest
    ):
        knowledge_space_dao.update_knowledge_space(space_id, space_request)

    """delete knowledge space"""
    def delete_knowledge_space(self, space_id: int):
        return knowledge_space_dao.delete_knowledge_space(space_id)

    """get document chunks"""
    def get_document_chunks(self, request: ChunkQueryRequest):
        query = DocumentChunkEntity(
            id=request.id,
            document_id=request.document_id,
            doc_name=request.doc_name,
            doc_type=request.doc_type
        )
        return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)

    def async_doc_embedding(self, client, chunk_docs, doc):
        logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}")
        try:
            vector_ids = client.knowledge_embedding_batch(chunk_docs)
            doc.status = SyncStatus.FINISHED.name
            doc.content = "embedding success"
            doc.vector_ids = ",".join(vector_ids)
        except Exception as e:
            doc.status = SyncStatus.FAILED.name
            doc.content = str(e)
        return knowledge_document_dao.update_knowledge_document(doc)
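One sharp edge in sync_knowledge_document above: threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc)) calls the method immediately and hands Thread its return value, so the embedding actually runs synchronously before thread.start(). The non-blocking form passes the callable and its arguments separately:

# Deferred invocation: the thread, not the caller, runs async_doc_embedding.
thread = threading.Thread(
    target=self.async_doc_embedding,
    args=(client, chunk_docs, doc),
)
thread.start()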

View File

@@ -0,0 +1,82 @@
from datetime import datetime

from sqlalchemy import Column, Integer, String, DateTime, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from pilot.configs.config import Config
from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest

CFG = Config()
Base = declarative_base()


class KnowledgeSpaceEntity(Base):
    __tablename__ = 'knowledge_space'
    id = Column(Integer, primary_key=True)
    name = Column(String(100))
    vector_type = Column(String(100))
    desc = Column(String(100))
    owner = Column(String(100))
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"KnowledgeSpaceEntity(id={self.id}, name='{self.name}', vector_type='{self.vector_type}', desc='{self.desc}', owner='{self.owner}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class KnowledgeSpaceDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
            echo=True)
        self.Session = sessionmaker(bind=self.db_engine)

    def create_knowledge_space(self, space: KnowledgeSpaceRequest):
        session = self.Session()
        knowledge_space = KnowledgeSpaceEntity(
            name=space.name,
            vector_type=space.vector_type,
            desc=space.desc,
            owner=space.owner,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now()
        )
        session.add(knowledge_space)
        session.commit()
        session.close()

    def get_knowledge_space(self, query: KnowledgeSpaceEntity):
        session = self.Session()
        knowledge_spaces = session.query(KnowledgeSpaceEntity)
        if query.id is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id)
        if query.name is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name)
        if query.vector_type is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type)
        if query.desc is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc)
        if query.owner is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner)
        if query.gmt_created is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created)
        if query.gmt_modified is not None:
            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified)
        knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc())
        result = knowledge_spaces.all()
        return result

    def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
        cursor = self.conn.cursor()
        query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
        cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id))
        self.conn.commit()
        cursor.close()

    def delete_knowledge_space(self, space_id: int):
        cursor = self.conn.cursor()
        query = "DELETE FROM knowledge_space WHERE id = %s"
        cursor.execute(query, (space_id,))
        self.conn.commit()
        cursor.close()
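As in the document DAO, update_knowledge_space and delete_knowledge_space rely on a raw self.conn that __init__ never creates; additionally, desc is a reserved word in MySQL, so the raw UPDATE statement above would need backticks around that column even with a live connection. A hypothetical session-based update, consistent with the ORM used elsewhere in the class:

# Hypothetical rewrite, not part of this commit: ORM update via the
# existing sessionmaker instead of the uninitialized self.conn.
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
    session = self.Session()
    session.query(KnowledgeSpaceEntity).filter(
        KnowledgeSpaceEntity.id == space_id
    ).update({
        KnowledgeSpaceEntity.name: space.name,
        KnowledgeSpaceEntity.vector_type: space.vector_type,
        KnowledgeSpaceEntity.desc: space.desc,
        KnowledgeSpaceEntity.owner: space.owner,
    })
    session.commit()
    session.close()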

View File

@@ -0,0 +1,74 @@
from typing import List

from pydantic import BaseModel


class KnowledgeQueryRequest(BaseModel):
    """query: knowledge query"""
    query: str
    """top_k: return topK documents"""
    top_k: int


class KnowledgeSpaceRequest(BaseModel):
    """name: knowledge space name"""
    name: str = None
    """vector_type: vector type"""
    vector_type: str = None
    """desc: description"""
    desc: str = None
    """owner: owner"""
    owner: str = None


class KnowledgeDocumentRequest(BaseModel):
    """doc_name: doc path"""
    doc_name: str
    """doc_type: doc type"""
    doc_type: str
    """text_chunk_size: text_chunk_size"""
    # text_chunk_size: int


class DocumentQueryRequest(BaseModel):
    """doc_name: doc path"""
    doc_name: str = None
    """doc_type: doc type"""
    doc_type: str = None
    """status: status"""
    status: str = None
    """page: page"""
    page: int = 1
    """page_size: page size"""
    page_size: int = 20


class DocumentSyncRequest(BaseModel):
    """doc_ids: doc ids"""
    doc_ids: List


class ChunkQueryRequest(BaseModel):
    """id: id"""
    id: int = None
    """document_id: doc id"""
    document_id: int = None
    """doc_name: doc path"""
    doc_name: str = None
    """doc_type: doc type"""
    doc_type: str = None
    """page: page"""
    page: int = 1
    """page_size: page size"""
    page_size: int = 20


class KnowledgeQueryResponse:
    """source: knowledge reference source"""
    source: str
    """score: knowledge vector query similarity score"""
    score: float = 0.0
    """text: raw text info"""
    text: str
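Unlike the request models, KnowledgeQueryResponse does not inherit from BaseModel, so the keyword construction in similar_query above (KnowledgeQueryResponse(text=..., source=...)) raises TypeError: a plain class with only annotations defines no such __init__. A sketch of the Pydantic form that matches how the controller uses it:

from pydantic import BaseModel

class KnowledgeQueryResponse(BaseModel):
    """source: knowledge reference source"""
    source: str = None
    """score: knowledge vector query similarity score"""
    score: float = 0.0
    """text: raw text info"""
    text: str = None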

View File

@@ -9,6 +9,7 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel

 global_counter = 0
@@ -22,10 +23,12 @@ from pilot.configs.model_config import *
 from pilot.model.llm_out.vicuna_base_llm import get_embeddings
 from pilot.model.loader import ModelLoader
 from pilot.server.chat_adapter import get_llm_chat_adapter
+from knowledge.knowledge_controller import router

 CFG = Config()


 class ModelWorker:
     def __init__(self, model_path, model_name, device, num_gpus=1):
         if model_path.endswith("/"):
@@ -103,7 +106,21 @@ worker = ModelWorker(
 )

 app = FastAPI()
+app.include_router(router)
+
+origins = [
+    "http://localhost",
+    "http://localhost:8000",
+    "http://localhost:3000",
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)


 class PromptRequest(BaseModel):
     prompt: str

View File

@@ -32,5 +32,10 @@ class ChromaStore(VectorStoreBase):
         logger.info("ChromaStore load document")
         texts = [doc.page_content for doc in documents]
         metadatas = [doc.metadata for doc in documents]
-        self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
+        ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
         self.vector_store_client.persist()
+        return ids
+
+    def delete_by_ids(self, ids):
+        collection = self.vector_store_client._collection
+        collection.delete(ids=ids)

View File

@@ -16,7 +16,7 @@ class VectorStoreConnector:
     def load_document(self, docs):
         """load document in vector database."""
-        self.client.load_document(docs)
+        return self.client.load_document(docs)

     def similar_search(self, docs, topk):
         """similar search in vector database."""
@@ -25,3 +25,6 @@ class VectorStoreConnector:
     def vector_name_exists(self):
         """is vector store name exist."""
         return self.client.vector_name_exists()
+
+    def delete_by_ids(self, ids):
+        self.client.delete_by_ids(ids=ids)
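Taken together, load_document now propagates the ids Chroma assigns and delete_by_ids removes exactly those vectors, which is the hook a future re-sync or document delete can use. A minimal sketch of that round trip (the connector's construction is not shown in this commit, so it is elided; vector_client stands for a VectorStoreConnector and doc for a KnowledgeDocumentEntity):

# Sketch of the id round trip these two hunks enable. The service stores the
# ids comma-joined on the document (see async_doc_embedding above).
ids = vector_client.load_document(chunk_docs)   # Chroma's ids, propagated upward
doc.vector_ids = ",".join(ids)
# ... later, e.g. before re-syncing the same document:
stale_ids = doc.vector_ids.split(",")
vector_client.delete_by_ids(ids=stale_ids)      # drop the stale chunks from the store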