mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-12 13:42:23 +00:00)
Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/openapi/api_v1/api_v1.py
#	pilot/server/dbgpt_server.py

commit 6f8f182d1d
assets/schema/knowledge_management.sql (new file, +41 lines)
@@ -0,0 +1,41 @@
+CREATE TABLE `knowledge_space` (
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
+    `name` varchar(100) NOT NULL COMMENT 'knowledge space name',
+    `vector_type` varchar(50) NOT NULL COMMENT 'vector type',
+    `desc` varchar(500) NOT NULL COMMENT 'description',
+    `owner` varchar(100) DEFAULT NULL COMMENT 'owner',
+    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    PRIMARY KEY (`id`),
+    KEY `idx_name` (`name`) COMMENT 'index:idx_name'
+) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge space table';
+
+CREATE TABLE `knowledge_document` (
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
+    `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
+    `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
+    `space` varchar(50) NOT NULL COMMENT 'knowledge space',
+    `chunk_size` int NOT NULL COMMENT 'chunk size',
+    `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time',
+    `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED',
+    `content` LONGTEXT NOT NULL COMMENT 'knowledge content',
+    `result` TEXT NULL COMMENT 'knowledge embedding sync result',
+    `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
+    `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    PRIMARY KEY (`id`),
+    KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name'
+) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document table';
+
+CREATE TABLE `document_chunk` (
+    `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id',
+    `doc_name` varchar(100) NOT NULL COMMENT 'document path name',
+    `doc_type` varchar(50) NOT NULL COMMENT 'doc type',
+    `document_id` int NOT NULL COMMENT 'document parent id',
+    `content` longtext NOT NULL COMMENT 'chunk content',
+    `meta_info` varchar(200) NOT NULL COMMENT 'metadata info',
+    `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
+    `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
+    PRIMARY KEY (`id`),
+    KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id'
+) ENGINE=InnoDB AUTO_INCREMENT=100001 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='knowledge document chunk detail';
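A minimal sketch of how the three new tables relate (spaces own documents; chunks point back to their document via document_id). Illustrative only: it assumes a local MySQL with this schema applied and pymysql installed; the connection parameters and space name are placeholders, not part of this commit.

import pymysql

# Placeholder credentials; assumes the schema above has been applied.
conn = pymysql.connect(
    host="localhost", user="root", password="", database="knowledge"
)
with conn.cursor() as cur:
    # Count chunks per document inside one knowledge space.
    cur.execute(
        "SELECT d.doc_name, COUNT(c.id) "
        "FROM knowledge_document d "
        "LEFT JOIN document_chunk c ON c.document_id = d.id "
        "WHERE d.space = %s GROUP BY d.doc_name",
        ("default",),
    )
    for doc_name, chunk_count in cur.fetchall():
        print(doc_name, chunk_count)
conn.close()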
@@ -19,7 +19,7 @@ const AgentPage = (props) => {
   });

   const { history, handleChatSubmit } = useAgentChat({
-    queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
+    queryAgentURL: `http://localhost:5000/v1/chat/completions`,
     queryBody: {
       conv_uid: props.params?.agentId,
       chat_mode: props.searchParams?.scene || 'chat_normal',
@@ -16,7 +16,7 @@ const Item = styled(Sheet)(({ theme }) => ({

 const Agents = () => {
   const { handleChatSubmit, history } = useAgentChat({
-    queryAgentURL: `http://30.183.154.8:5000/v1/chat/completions`,
+    queryAgentURL: `http://localhost:5000/v1/chat/completions`,
   });

   const data = [
@@ -6,12 +6,11 @@ from typing import List
 import markdown
 from bs4 import BeautifulSoup
 from langchain.schema import Document
-from langchain.text_splitter import SpacyTextSplitter
+from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter

 from pilot.configs.config import Config
 from pilot.embedding_engine import SourceEmbedding, register
 from pilot.embedding_engine.EncodeTextLoader import EncodeTextLoader
-from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter

 CFG = Config()
@@ -30,12 +29,20 @@ class MarkdownEmbedding(SourceEmbedding):
     def read(self):
         """Load from markdown path."""
         loader = EncodeTextLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm",
-            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
-            chunk_overlap=100,
-        )
-        return loader.load_and_split(textsplitter)
+        if CFG.LANGUAGE == "en":
+            text_splitter = CharacterTextSplitter(
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=20,
+                length_function=len,
+            )
+        else:
+            text_splitter = SpacyTextSplitter(
+                pipeline="zh_core_web_sm",
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=100,
+            )
+        return loader.load_and_split(text_splitter)

     @register
     def data_process(self, documents: List[Document]):
@@ -4,7 +4,7 @@ from typing import List

 from langchain.document_loaders import PyPDFLoader
 from langchain.schema import Document
-from langchain.text_splitter import SpacyTextSplitter
+from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter

 from pilot.configs.config import Config
 from pilot.embedding_engine import SourceEmbedding, register
@@ -28,12 +28,24 @@ class PDFEmbedding(SourceEmbedding):
         # textsplitter = CHNDocumentSplitter(
         #     pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
         # )
-        textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm",
-            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
-            chunk_overlap=100,
-        )
-        return loader.load_and_split(textsplitter)
+        # textsplitter = SpacyTextSplitter(
+        #     pipeline="zh_core_web_sm",
+        #     chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+        #     chunk_overlap=100,
+        # )
+        if CFG.LANGUAGE == "en":
+            text_splitter = CharacterTextSplitter(
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=20,
+                length_function=len,
+            )
+        else:
+            text_splitter = SpacyTextSplitter(
+                pipeline="zh_core_web_sm",
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=100,
+            )
+        return loader.load_and_split(text_splitter)

     @register
     def data_process(self, documents: List[Document]):
@@ -4,10 +4,11 @@ from typing import List

 from langchain.document_loaders import UnstructuredPowerPointLoader
 from langchain.schema import Document
-from langchain.text_splitter import SpacyTextSplitter
+from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter

 from pilot.configs.config import Config
 from pilot.embedding_engine import SourceEmbedding, register
 from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter

 CFG = Config()
@@ -25,12 +26,24 @@ class PPTEmbedding(SourceEmbedding):
     def read(self):
         """Load from ppt path."""
         loader = UnstructuredPowerPointLoader(self.file_path)
-        textsplitter = SpacyTextSplitter(
-            pipeline="zh_core_web_sm",
-            chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
-            chunk_overlap=200,
-        )
-        return loader.load_and_split(textsplitter)
+        # textsplitter = SpacyTextSplitter(
+        #     pipeline="zh_core_web_sm",
+        #     chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+        #     chunk_overlap=200,
+        # )
+        if CFG.LANGUAGE == "en":
+            text_splitter = CharacterTextSplitter(
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=20,
+                length_function=len,
+            )
+        else:
+            text_splitter = SpacyTextSplitter(
+                pipeline="zh_core_web_sm",
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=100,
+            )
+        return loader.load_and_split(text_splitter)

     @register
     def data_process(self, documents: List[Document]):
@@ -2,12 +2,12 @@
 # -*- coding: utf-8 -*-
 from typing import List

-from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.document_loaders import UnstructuredWordDocumentLoader
 from langchain.schema import Document
 from langchain.text_splitter import CharacterTextSplitter, SpacyTextSplitter

 from pilot.configs.config import Config
 from pilot.embedding_engine import SourceEmbedding, register
 from pilot.embedding_engine.chn_document_splitter import CHNDocumentSplitter

 CFG = Config()
@@ -25,10 +25,19 @@ class WordEmbedding(SourceEmbedding):
     def read(self):
         """Load from word path."""
         loader = UnstructuredWordDocumentLoader(self.file_path)
-        textsplitter = CHNDocumentSplitter(
-            pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE
-        )
-        return loader.load_and_split(textsplitter)
+        if CFG.LANGUAGE == "en":
+            text_splitter = CharacterTextSplitter(
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=20,
+                length_function=len,
+            )
+        else:
+            text_splitter = SpacyTextSplitter(
+                pipeline="zh_core_web_sm",
+                chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE,
+                chunk_overlap=100,
+            )
+        return loader.load_and_split(text_splitter)

     @register
     def data_process(self, documents: List[Document]):
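The same LANGUAGE-conditional splitter branch now appears in the markdown, PDF, PPT, and word embeddings above. A sketch (not part of this commit; the helper name is hypothetical) of how it could be factored into one place:

from langchain.text_splitter import (
    CharacterTextSplitter,
    SpacyTextSplitter,
    TextSplitter,
)


def build_text_splitter(language: str, chunk_size: int) -> TextSplitter:
    """English: plain character splitting; otherwise: the Chinese spaCy pipeline."""
    if language == "en":
        return CharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=20, length_function=len
        )
    return SpacyTextSplitter(
        pipeline="zh_core_web_sm", chunk_size=chunk_size, chunk_overlap=100
    )

Each read() would then reduce to return loader.load_and_split(build_text_splitter(CFG.LANGUAGE, CFG.KNOWLEDGE_CHUNK_SIZE)).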
@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder

 default_db_path = os.path.join(os.getcwd(), "message")
 duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = 'chat_history'
+table_name = "chat_history"

 CFG = Config()


 class DuckdbHistoryMemory(BaseChatHistoryMemory):
-
     def __init__(self, chat_session_id: str):
         self.chat_seesion_id = chat_session_id
         os.makedirs(default_db_path, exist_ok=True)
@@ -29,10 +28,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         self.__init_chat_history_tables()

     def __init_chat_history_tables(self):
-
         # Check whether the chat_history table already exists
-        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
-                                      [table_name]).fetchall()
+        result = self.connect.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
+        ).fetchall()

         if not result:
             # If the table does not exist, create it
@@ -74,8 +73,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
-            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
-                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
+            cursor.execute(
+                "UPDATE chat_history set messages=? where conv_uid=?",
+                [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
+            )
         else:
             cursor.execute(
                 "INSERT INTO chat_history(conv_uid, chat_mode, summary, user_name, messages)VALUES(?,?,?,?,?)",
@@ -85,13 +86,17 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):

     def clear(self) -> None:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         self.connect.commit()

     def delete(self) -> bool:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         return True

@@ -100,7 +105,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
             if user_name:
-                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
+                cursor.execute(
+                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
+                )
             else:
                 cursor.execute("SELECT * FROM chat_history limit 20")
             # Get the column names of the query result
@@ -118,7 +125,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):

     def get_messages(self) -> List[OnceConversation]:
         cursor = self.connect.cursor()
-        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         context = cursor.fetchone()
         if context:
             if context[0]:
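A usage sketch for the history memory reformatted above (illustrative only; the import path is a guess from the class name and may differ in the repo):

# Hypothetical import path for DuckdbHistoryMemory.
from pilot.memory.chat_history.duckdb_history_memory import DuckdbHistoryMemory

memory = DuckdbHistoryMemory(chat_session_id="conv-123")
conversations = memory.get_messages()  # rows are keyed by conv_uid
memory.clear()                         # removes only this conversation's history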
pilot/openapi/knowledge/__init__.py (new file, 0 lines)
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import List

-from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
+from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
 from sqlalchemy.orm import declarative_base, sessionmaker

 from pilot.configs.config import Config
@@ -83,6 +83,30 @@ class DocumentChunkDao:
         result = document_chunks.all()
         return result

+    def get_document_chunks_count(self, query: DocumentChunkEntity):
+        session = self.Session()
+        document_chunks = session.query(func.count(DocumentChunkEntity.id))
+        if query.id is not None:
+            document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
+        if query.document_id is not None:
+            document_chunks = document_chunks.filter(
+                DocumentChunkEntity.document_id == query.document_id
+            )
+        if query.doc_type is not None:
+            document_chunks = document_chunks.filter(
+                DocumentChunkEntity.doc_type == query.doc_type
+            )
+        if query.doc_name is not None:
+            document_chunks = document_chunks.filter(
+                DocumentChunkEntity.doc_name == query.doc_name
+            )
+        if query.meta_info is not None:
+            document_chunks = document_chunks.filter(
+                DocumentChunkEntity.meta_info == query.meta_info
+            )
+        count = document_chunks.scalar()
+        return count
+
     # def update_knowledge_document(self, document:KnowledgeDocumentEntity):
     #     session = self.Session()
     #     updated_space = session.merge(document)
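The new get_document_chunks_count builds a COUNT query and narrows it field by field. A sketch (not part of this commit; names are hypothetical) of the same pattern written generically, which would also cover the analogous count method added to KnowledgeDocumentDao below:

from sqlalchemy import func


def count_with_filters(session, entity_cls, query, fields) -> int:
    """COUNT(id) with an equality filter for every non-None query field."""
    stmt = session.query(func.count(entity_cls.id))
    for field in fields:
        value = getattr(query, field, None)
        if value is not None:
            stmt = stmt.filter(getattr(entity_cls, field) == value)
    return stmt.scalar()


# e.g. count_with_filters(session, DocumentChunkEntity, query,
#                         ("id", "document_id", "doc_type", "doc_name", "meta_info"))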
@@ -1,11 +1,13 @@
 import os
+import shutil
 from tempfile import NamedTemporaryFile

-from fastapi import APIRouter, File, UploadFile
+from fastapi import APIRouter, File, UploadFile, Request, Form

 from langchain.embeddings import HuggingFaceEmbeddings

 from pilot.configs.config import Config
-from pilot.configs.model_config import LLM_MODEL_CONFIG
+from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
+
 from pilot.openapi.api_v1.api_view_model import Result
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
@@ -74,18 +76,43 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):


 @router.post("/knowledge/{space_name}/document/upload")
-async def document_sync(space_name: str, file: UploadFile = File(...)):
+async def document_upload(
+    space_name: str,
+    doc_name: str = Form(...),
+    doc_type: str = Form(...),
+    doc_file: UploadFile = File(...),
+):
     print(f"/document/upload params: {space_name}")
     try:
-        with NamedTemporaryFile(delete=False) as tmp:
-            tmp.write(file.read())
-            tmp_path = tmp.name
-            tmp_content = tmp.read()
-
-        return {"file_path": tmp_path, "file_content": tmp_content}
-        Result.succ([])
+        if doc_file:
+            if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name)):
+                os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name))
+            with NamedTemporaryFile(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
+            ) as tmp:
+                tmp.write(await doc_file.read())
+                tmp_path = tmp.name
+            shutil.move(
+                tmp_path,
+                os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
+            )
+            request = KnowledgeDocumentRequest()
+            request.doc_name = doc_name
+            request.doc_type = doc_type
+            request.content = os.path.join(
+                KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+            )
+            knowledge_space_service.create_knowledge_document(
+                space=space_name, request=request
+            )
+            return Result.succ([])
+        return Result.faild(code="E000X", msg="doc_file is None")
     except Exception as e:
-        return Result.faild(code="E000X", msg=f"document sync error {e}")
+        return Result.faild(code="E000X", msg=f"document add error {e}")


 @router.post("/knowledge/{space_name}/document/sync")
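A client-side sketch for exercising the reworked multipart upload route (host, port, and the doc_type value are assumptions, not part of this commit):

import requests

with open("intro.md", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/knowledge/default/document/upload",
        data={"doc_name": "intro.md", "doc_type": "DOCUMENT"},  # doc_type value is a guess
        files={"doc_file": f},
    )
print(resp.json())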
@@ -1,6 +1,6 @@
 from datetime import datetime

-from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine
+from sqlalchemy import Column, String, DateTime, Integer, Text, create_engine, func
 from sqlalchemy.orm import declarative_base, sessionmaker

 from pilot.configs.config import Config
@@ -92,15 +92,41 @@ class KnowledgeDocumentDao:
         result = knowledge_documents.all()
         return result

+    def get_knowledge_documents_count(self, query):
+        session = self.Session()
+        knowledge_documents = session.query(func.count(KnowledgeDocumentEntity.id))
+        if query.id is not None:
+            knowledge_documents = knowledge_documents.filter(
+                KnowledgeDocumentEntity.id == query.id
+            )
+        if query.doc_name is not None:
+            knowledge_documents = knowledge_documents.filter(
+                KnowledgeDocumentEntity.doc_name == query.doc_name
+            )
+        if query.doc_type is not None:
+            knowledge_documents = knowledge_documents.filter(
+                KnowledgeDocumentEntity.doc_type == query.doc_type
+            )
+        if query.space is not None:
+            knowledge_documents = knowledge_documents.filter(
+                KnowledgeDocumentEntity.space == query.space
+            )
+        if query.status is not None:
+            knowledge_documents = knowledge_documents.filter(
+                KnowledgeDocumentEntity.status == query.status
+            )
+        count = knowledge_documents.scalar()
+        return count
+
     def update_knowledge_document(self, document: KnowledgeDocumentEntity):
         session = self.Session()
         updated_space = session.merge(document)
         session.commit()
         return updated_space.id

-    def delete_knowledge_document(self, document_id: int):
-        cursor = self.conn.cursor()
-        query = "DELETE FROM knowledge_document WHERE id = %s"
-        cursor.execute(query, (document_id,))
-        self.conn.commit()
-        cursor.close()
+    #
+    # def delete_knowledge_document(self, document_id: int):
+    #     cursor = self.conn.cursor()
+    #     query = "DELETE FROM knowledge_document WHERE id = %s"
+    #     cursor.execute(query, (document_id,))
+    #     self.conn.commit()
+    #     cursor.close()
@@ -25,6 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
 )
 from enum import Enum

+from pilot.openapi.knowledge.request.knowledge_response import (
+    ChunkQueryResponse,
+    DocumentQueryResponse,
+)

 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
@@ -72,6 +76,7 @@ class KnowledgeService:
             status=SyncStatus.TODO.name,
             last_sync=datetime.now(),
             content=request.content,
+            result="",
         )
         knowledge_document_dao.create_knowledge_document(document)
         return True
@@ -93,9 +98,13 @@ class KnowledgeService:
             space=space,
             status=request.status,
         )
-        return knowledge_document_dao.get_knowledge_documents(
+        res = DocumentQueryResponse()
+        res.data = knowledge_document_dao.get_knowledge_documents(
             query, page=request.page, page_size=request.page_size
         )
+        res.total = knowledge_document_dao.get_knowledge_documents_count(query)
+        res.page = request.page
+        return res

     """sync knowledge document chunk into vector store"""

@@ -106,6 +115,8 @@ class KnowledgeService:
             space=space_name,
         )
         doc = knowledge_document_dao.get_knowledge_documents(query)[0]
+        if doc.status == SyncStatus.RUNNING.name or doc.status == SyncStatus.FINISHED.name:
+            raise Exception(f" doc:{doc.doc_name} status is {doc.status}, can not sync")
         client = KnowledgeEmbedding(
             knowledge_source=doc.content,
             knowledge_type=doc.doc_type.upper(),
@@ -164,9 +175,13 @@ class KnowledgeService:
             doc_name=request.doc_name,
             doc_type=request.doc_type,
         )
-        return document_chunk_dao.get_document_chunks(
+        res = ChunkQueryResponse()
+        res.data = document_chunk_dao.get_document_chunks(
             query, page=request.page, page_size=request.page_size
         )
+        res.total = document_chunk_dao.get_document_chunks_count(query)
+        res.page = request.page
+        return res

     def async_doc_embedding(self, client, chunk_docs, doc):
         logger.info(
pilot/openapi/knowledge/request/__init__.py (new file, 0 lines)
@@ -1,6 +1,7 @@
 from typing import List

 from pydantic import BaseModel
+from fastapi import UploadFile


 class KnowledgeQueryRequest(BaseModel):
@@ -26,11 +27,14 @@ class KnowledgeSpaceRequest(BaseModel):
 class KnowledgeDocumentRequest(BaseModel):
     """doc_name: doc path"""

-    doc_name: str
+    doc_name: str = None
     """doc_type: doc type"""
-    doc_type: str
+    doc_type: str = None
+    """content: content"""
+    content: str = None
     """content: content"""
     source: str = None

     """text_chunk_size: text_chunk_size"""
     # text_chunk_size: int
pilot/openapi/knowledge/request/knowledge_response.py (new file, +23 lines)
@@ -0,0 +1,23 @@
+from typing import List
+
+from pydantic import BaseModel
+
+
+class ChunkQueryResponse(BaseModel):
+    """data: data"""
+
+    data: List = None
+    """total: total size"""
+    total: int = None
+    """page: current page"""
+    page: int = None
+
+
+class DocumentQueryResponse(BaseModel):
+    """data: data"""
+
+    data: List = None
+    """total: total size"""
+    total: int = None
+    """page: current page"""
+    page: int = None
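These models carry the paging contract the service methods above now fill in (data: one page of rows; total: the full count; page: the page requested). A consumer can derive the page count from total and its own page_size; a small sketch (helper name is hypothetical, not part of this commit):

import math

from pilot.openapi.knowledge.request.knowledge_response import DocumentQueryResponse


def page_count(res: DocumentQueryResponse, page_size: int) -> int:
    """Number of pages implied by a paginated query response."""
    return math.ceil((res.total or 0) / page_size)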
@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1:], start=i + 1):
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i: j + 1]
+        return s[i : j + 1]

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
         cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
-                .replace("\n", " ")
-                .replace("\\n", " ")
-                .replace("\\", " ")
+            .replace("\n", " ")
+            .replace("\\n", " ")
+            .replace("\\", " ")
         )
         return cleaned_output
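The brace-matching extraction above is easy to sanity-check in isolation. A standalone copy (same logic as __extract_json) with an example input:

def extract_json(s: str) -> str:
    """Return the first balanced {...} region of s."""
    i = s.index("{")
    count = 1  # number of '{' not yet closed
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # the closing '}' must be found
    return s[i : j + 1]


print(extract_json('reply: {"sql": "SELECT 1", "meta": {"k": 1}} extra'))
# -> {"sql": "SELECT 1", "meta": {"k": 1}}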
@@ -60,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
-            self,
-            chat_mode,
-            chat_session_id,
-            current_user_input,
+        self,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -172,11 +172,18 @@ class BaseChat(ABC):
         print("[TEST: output]:", rsp_str)

         ### output parse
-        ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
-                                                                                        self.prompt_template.sep)
+        ai_response_text = (
+            self.prompt_template.output_parser.parse_model_nostream_resp(
+                rsp_str, self.prompt_template.sep
+            )
+        )
         ### model result deal
         self.current_message.add_ai_message(ai_response_text)
-        prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+        prompt_define_response = (
+            self.prompt_template.output_parser.parse_prompt_response(
+                ai_response_text
+            )
+        )
         result = self.do_action(prompt_define_response)

         if hasattr(prompt_define_response, "thoughts"):
@@ -236,7 +243,9 @@ class BaseChat(ABC):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
         for system_conv in system_convs:
-            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            system_text += (
+                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            )
         return system_text

     def __load_user_message(self):
@@ -250,13 +259,16 @@ class BaseChat(ABC):
         example_text = ""
         if self.prompt_template.example_selector:
             for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv['messages']:
-                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                for round_message in round_conv["messages"]:
+                    if not round_message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         example_text += (
-                            round_message['type']
-                            + ":"
-                            + round_message['data']['content']
-                            + self.prompt_template.sep
+                            round_message["type"]
+                            + ":"
+                            + round_message["data"]["content"]
+                            + self.prompt_template.sep
                         )
         return example_text

@@ -268,37 +280,46 @@ class BaseChat(ABC):
                 f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
             )
             if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]['messages']:
-                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                for first_message in self.history_message[0]["messages"]:
+                    if not first_message["type"] in [
+                        ViewMessage.type,
+                        SystemMessage.type,
+                    ]:
                         history_text += (
-                            first_message['type']
-                            + ":"
-                            + first_message['data']['content']
-                            + self.prompt_template.sep
+                            first_message["type"]
+                            + ":"
+                            + first_message["data"]["content"]
+                            + self.prompt_template.sep
                         )

                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
-                    for round_message in round_conv['messages']:
-                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            SystemMessage.type,
+                            ViewMessage.type,
+                        ]:
                             history_text += (
-                                round_message['type']
-                                + ":"
-                                + round_message['data']['content']
-                                + self.prompt_template.sep
+                                round_message["type"]
+                                + ":"
+                                + round_message["data"]["content"]
+                                + self.prompt_template.sep
                             )

         else:
             ### use all history
             for conversation in self.history_message:
-                for message in conversation['messages']:
+                for message in conversation["messages"]:
                     ### history messages do not include prompt and view info
-                    if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    if not message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         history_text += (
-                            message['type']
-                            + ":"
-                            + message['data']['content']
-                            + self.prompt_template.sep
+                            message["type"]
+                            + ":"
+                            + message["data"]["content"]
+                            + self.prompt_template.sep
                         )

         return history_text
@@ -9,6 +9,7 @@ from pilot.scene.chat_db.auto_execute.chat import ChatWithDbAutoExecute
 from pilot.scene.chat_knowledge.url.chat import ChatUrlKnowledge
 from pilot.scene.chat_knowledge.custom.chat import ChatNewKnowledge
 from pilot.scene.chat_knowledge.default.chat import ChatDefaultKnowledge
+from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
 from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary

@@ -18,7 +18,7 @@ from pilot.configs.model_config import (
     LOGDIR,
 )

-from pilot.scene.chat_knowledge.default.prompt import prompt
+from pilot.scene.chat_knowledge.v1.prompt import prompt
 from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding

 CFG = Config()
@@ -98,9 +98,10 @@ class OnceConversation:
             system_convs.append(message)
         return system_convs

+
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if hasattr(once, 'start_date') and once.start_date:
+    if hasattr(once, "start_date") and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:
@@ -23,15 +23,23 @@ from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
+from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
+
 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler

 static_file_path = os.path.join(os.getcwd(), "server/static")

 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")

+
+def signal_handler(sig, frame):
+    print("in order to avoid chroma db atexit problem")
+    os._exit(0)
+
+
 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
         *args, **kwargs,
@@ -55,23 +63,27 @@ app.add_middleware(
 )

 app.mount("/static", StaticFiles(directory=static_file_path), name="static")

+app.include_router(knowledge_router)
 app.include_router(api_v1)
 app.add_exception_handler(RequestValidationError, validation_exception_handler)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )

     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--port", type=int, default=5000)
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument("--share", default=False, action="store_true")
     signal.signal(signal.SIGINT, signal_handler)

     # init server config
     args = parser.parse_args()
     server_init(args)
     CFG.NEW_SERVER_MODE = True
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=5000)
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
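Usage note: with the port flag honored again, the web server can be started on a chosen port, e.g. python pilot/server/dbgpt_server.py --port 5000 (invocation path assumed from the repo layout; the conflict list above references pilot/server/dbgpt_server.py).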
@@ -9,7 +9,8 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+
+# from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel

 global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:

         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_sequence_length"
+                self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_position_embeddings"
+                self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings

@@ -60,22 +61,22 @@ class ModelWorker:

     def get_queue_length(self):
         if (
-                model_semaphore is None
-                or model_semaphore._value is None
-                or model_semaphore._waiters is None
+            model_semaphore is None
+            or model_semaphore._value is None
+            or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
-                    CFG.LIMIT_MODEL_CONCURRENCY
-                    - model_semaphore._value
-                    + len(model_semaphore._waiters)
+                CFG.LIMIT_MODEL_CONCURRENCY
+                - model_semaphore._value
+                + len(model_semaphore._waiters)
             )

     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
-                    self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
+                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
 )

 app = FastAPI()
-from pilot.openapi.knowledge.knowledge_controller import router
-
-app.include_router(router)
-
-origins = [
-    "http://localhost",
-    "http://localhost:8000",
-    "http://localhost:3000",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# from pilot.openapi.knowledge.knowledge_controller import router
+#
+# app.include_router(router)
+#
+# origins = [
+#     "http://localhost",
+#     "http://localhost:8000",
+#     "http://localhost:3000",
+# ]
+#
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=origins,
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )


 class PromptRequest(BaseModel):
@@ -40,6 +40,7 @@ def server_init(args):
     cfg = Config()

     from pilot.server.llmserver import worker
+
     worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)