Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/memory/chat_history/duckdb_history.py
#	pilot/openapi/api_v1/api_v1.py
#	pilot/scene/base.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_execution/example.py
This commit is contained in:
tuyang.yhj 2023-06-28 11:40:22 +08:00
commit caa1a41065
46 changed files with 634 additions and 291 deletions

View File

@ -2,14 +2,60 @@
import { useRouter, useSearchParams } from 'next/navigation' import { useRouter, useSearchParams } from 'next/navigation'
import React, { useState, useEffect } from 'react' import React, { useState, useEffect } from 'react'
import { Button, Table } from '@/lib/mui' import {
Button,
Table,
Sheet,
Modal,
Box,
Stack,
Input,
styled
} from '@/lib/mui'
import moment from 'moment' import moment from 'moment'
import { message } from 'antd' import { message } from 'antd'
const stepsOfAddingDocument = [
'Choose a Datasource type',
'Setup the Datasource'
]
const documentTypeList = [
{
type: 'text',
title: 'Text',
subTitle: 'Paste some text'
},
{
type: 'webPage',
title: 'Web Page',
subTitle: 'Crawl text from a web page'
},
{
type: 'file',
title: 'File',
subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
}
]
const Item = styled(Sheet)(({ theme }) => ({
width: '50%',
backgroundColor:
theme.palette.mode === 'dark' ? theme.palette.background.level1 : '#fff',
...theme.typography.body2,
padding: theme.spacing(1),
textAlign: 'center',
borderRadius: 4,
color: theme.vars.palette.text.secondary
}))
const Documents = () => { const Documents = () => {
const router = useRouter() const router = useRouter()
const spaceName = useSearchParams().get('name') const spaceName = useSearchParams().get('name')
const [isAddDocumentModalShow, setIsAddDocumentModalShow] =
useState<boolean>(false)
const [activeStep, setActiveStep] = useState<number>(0)
const [documents, setDocuments] = useState<any>([]) const [documents, setDocuments] = useState<any>([])
const [webPageUrl, setWebPageUrl] = useState<string>('')
const [documentName, setDocumentName] = useState<string>('')
useEffect(() => { useEffect(() => {
async function fetchDocuments() { async function fetchDocuments() {
const res = await fetch( const res = await fetch(
@ -31,6 +77,19 @@ const Documents = () => {
}, []) }, [])
return ( return (
<div className="p-4"> <div className="p-4">
<Sheet
sx={{
display: 'flex',
flexDirection: 'row-reverse'
}}
>
<Button
variant="outlined"
onClick={() => setIsAddDocumentModalShow(true)}
>
+ Add Datasource
</Button>
</Sheet>
<Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}> <Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
<thead> <thead>
<tr> <tr>
@ -70,9 +129,9 @@ const Documents = () => {
) )
const data = await res.json() const data = await res.json()
if (data.success) { if (data.success) {
message.success('success'); message.success('success')
} else { } else {
message.error(data.err_msg || 'failed'); message.error(data.err_msg || 'failed')
} }
}} }}
> >
@ -95,6 +154,137 @@ const Documents = () => {
))} ))}
</tbody> </tbody>
</Table> </Table>
<Modal
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
'z-index': 1000
}}
open={isAddDocumentModalShow}
onClose={() => setIsAddDocumentModalShow(false)}
>
<Sheet
variant="outlined"
sx={{
width: 800,
borderRadius: 'md',
p: 3,
boxShadow: 'lg'
}}
>
<Box sx={{ width: '100%' }}>
<Stack spacing={2} direction="row">
{stepsOfAddingDocument.map((item: any, index: number) => (
<Item
key={item}
sx={{ fontWeight: activeStep === index ? 'bold' : '' }}
>
{item}
</Item>
))}
</Stack>
</Box>
{activeStep === 0 ? (
<>
<Box sx={{ margin: '30px auto' }}>
{documentTypeList.map((item: any) => (
<Sheet
key={item.type}
sx={{
boxSizing: 'border-box',
height: '80px',
padding: '12px',
display: 'flex',
flexDirection: 'column',
justifyContent: 'space-between',
border: '1px solid gray',
borderRadius: '6px',
marginBottom: '20px',
cursor: 'pointer'
}}
onClick={() => {
if (item.type === 'webPage') {
setActiveStep(1)
}
}}
>
<Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>
{item.title}
</Sheet>
<Sheet>{item.subTitle}</Sheet>
</Sheet>
))}
</Box>
</>
) : (
<>
<Box sx={{ margin: '30px auto' }}>
Name:
<Input
placeholder="Please input the name"
onChange={(e: any) => setDocumentName(e.target.value)}
sx={{ marginBottom: '20px' }}
/>
Web Page URL:
<Input
placeholder="Please input the Web Page URL"
onChange={(e: any) => setWebPageUrl(e.target.value)}
/>
</Box>
<Button
onClick={async () => {
if (documentName === '') {
message.error('Please input the name')
return
}
if (webPageUrl === '') {
message.error('Please input the Web Page URL')
return
}
const res = await fetch(
`http://localhost:8000/knowledge/${spaceName}/document/add`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
doc_name: documentName,
content: webPageUrl,
doc_type: 'URL'
})
}
)
const data = await res.json()
if (data.success) {
message.success('success')
setIsAddDocumentModalShow(false)
const res = await fetch(
`http://localhost:8000/knowledge/${spaceName}/document/list`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({})
}
)
const data = await res.json()
if (data.success) {
setDocuments(data.data)
}
} else {
message.error(data.err_msg || 'failed')
}
}}
>
Finish
</Button>
</>
)}
</Sheet>
</Modal>
</div> </div>
) )
} }

View File

@ -1,8 +1,7 @@
'use client' 'use client'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import type { ProFormInstance } from '@ant-design/pro-components' import React, { useState, useEffect } from 'react'
import React, { useState, useRef, useEffect } from 'react'
import { message } from 'antd' import { message } from 'antd'
import { import {
Modal, Modal,
@ -31,16 +30,33 @@ const stepsOfAddingSpace = [
'Choose a Datasource type', 'Choose a Datasource type',
'Setup the Datasource' 'Setup the Datasource'
] ]
const documentTypeList = [
{
type: 'text',
title: 'Text',
subTitle: 'Paste some text'
},
{
type: 'webPage',
title: 'Web Page',
subTitle: 'Crawl text from a web page'
},
{
type: 'file',
title: 'File',
subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
}
]
const Index = () => { const Index = () => {
const router = useRouter() const router = useRouter()
const formRef = useRef<ProFormInstance>()
const [activeStep, setActiveStep] = useState<number>(0) const [activeStep, setActiveStep] = useState<number>(0)
const [knowledgeSpaceList, setKnowledgeSpaceList] = useState<any>([]) const [knowledgeSpaceList, setKnowledgeSpaceList] = useState<any>([])
const [isAddKnowledgeSpaceModalShow, setIsAddKnowledgeSpaceModalShow] = const [isAddKnowledgeSpaceModalShow, setIsAddKnowledgeSpaceModalShow] =
useState<boolean>(false) useState<boolean>(false)
const [knowledgeSpaceName, setKnowledgeSpaceName] = useState<string>('') const [knowledgeSpaceName, setKnowledgeSpaceName] = useState<string>('')
const [webPageUrl, setWebPageUrl] = useState<string>('') const [webPageUrl, setWebPageUrl] = useState<string>('')
const [documentName, setDocumentName] = useState<string>('')
useEffect(() => { useEffect(() => {
async function fetchData() { async function fetchData() {
const res = await fetch('http://localhost:8000/knowledge/space/list', { const res = await fetch('http://localhost:8000/knowledge/space/list', {
@ -59,15 +75,21 @@ const Index = () => {
}, []) }, [])
return ( return (
<> <>
<div className="page-header p-4"> <Sheet sx={{
<div className="page-header-title">Knowledge Spaces</div> display: "flex",
justifyContent: "space-between"
}} className="p-4">
<Sheet sx={{
fontSize: '30px',
fontWeight: 'bold'
}}>Knowledge Spaces</Sheet>
<Button <Button
onClick={() => setIsAddKnowledgeSpaceModalShow(true)} onClick={() => setIsAddKnowledgeSpaceModalShow(true)}
variant="outlined" variant="outlined"
> >
+ New Knowledge Space + New Knowledge Space
</Button> </Button>
</div> </Sheet>
<div className="page-body p-4"> <div className="page-body p-4">
<Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}> <Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
<thead> <thead>
@ -100,7 +122,6 @@ const Index = () => {
</Table> </Table>
</div> </div>
<Modal <Modal
title="Add Knowledge Space"
sx={{ sx={{
display: 'flex', display: 'flex',
justifyContent: 'center', justifyContent: 'center',
@ -191,14 +212,42 @@ const Index = () => {
) : activeStep === 1 ? ( ) : activeStep === 1 ? (
<> <>
<Box sx={{ margin: '30px auto' }}> <Box sx={{ margin: '30px auto' }}>
<Button variant="outlined" onClick={() => setActiveStep(2)}> {documentTypeList.map((item: any) => (
Web Page <Sheet
</Button> key={item.type}
sx={{
boxSizing: 'border-box',
height: '80px',
padding: '12px',
display: 'flex',
flexDirection: 'column',
justifyContent: 'space-between',
border: '1px solid gray',
borderRadius: '6px',
marginBottom: '20px',
cursor: 'pointer'
}}
onClick={() => {
if (item.type === 'webPage') {
setActiveStep(2);
}
}}
>
<Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>{item.title}</Sheet>
<Sheet>{item.subTitle}</Sheet>
</Sheet>
))}
</Box> </Box>
</> </>
) : ( ) : (
<> <>
<Box sx={{ margin: '30px auto' }}> <Box sx={{ margin: '30px auto' }}>
Name:
<Input
placeholder="Please input the name"
onChange={(e: any) => setDocumentName(e.target.value)}
sx={{ marginBottom: '20px'}}
/>
Web Page URL: Web Page URL:
<Input <Input
placeholder="Please input the Web Page URL" placeholder="Please input the Web Page URL"
@ -207,6 +256,14 @@ const Index = () => {
</Box> </Box>
<Button <Button
onClick={async () => { onClick={async () => {
if (documentName === '') {
message.error('Please input the name');
return;
}
if (webPageUrl === '') {
message.error('Please input the Web Page URL');
return;
}
const res = await fetch( const res = await fetch(
`http://localhost:8000/knowledge/${knowledgeSpaceName}/document/add`, `http://localhost:8000/knowledge/${knowledgeSpaceName}/document/add`,
{ {
@ -215,7 +272,8 @@ const Index = () => {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify({ body: JSON.stringify({
doc_name: webPageUrl, doc_name: documentName,
content: webPageUrl,
doc_type: 'URL' doc_type: 'URL'
}) })
} }
@ -235,24 +293,6 @@ const Index = () => {
)} )}
</Sheet> </Sheet>
</Modal> </Modal>
<style jsx>{`
.page-header {
display: flex;
justify-content: space-between;
.page-header-title {
font-size: 30px;
font-weight: bold;
}
}
.datasource-type-wrap {
height: 100px;
line-height: 100px;
border: 1px solid black;
border-radius: 20px;
margin-bottom: 20px;
cursor: pointer;
}
`}</style>
</> </>
) )
} }

View File

@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if isinstance(obj, set): if isinstance(obj, set):
return list(obj) return list(obj)
elif hasattr(obj, '__dict__'): elif hasattr(obj, "__dict__"):
return obj.__dict__ return obj.__dict__
else: else:
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)

View File

@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load: if not cfg.plugins_auto_load:
print("not auto load_native_plugins") print("not auto load_native_plugins")
return return
def load_from_git(cfg: Config): def load_from_git(cfg: Config):
print("async load_native_plugins") print("async load_native_plugins")
branch_name = cfg.plugins_git_branch branch_name = cfg.plugins_git_branch
@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
url = "https://github.com/csunny/{repo}/archive/{branch}.zip" url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try: try:
session = requests.Session() session = requests.Session()
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name), response = session.get(
headers={'Authorization': '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'}) url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200: if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR) plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
for file in files: for file in files:
os.remove(file) os.remove(file)
now = datetime.datetime.now() now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S') time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name) print(file_name)
with open(file_name, "wb") as f: with open(file_name, "wb") as f:
@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
t.start() t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them. """Scan the plugins directory for plugins and loads them.

View File

@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
THREE = auto() THREE = auto()
FOUR = auto() FOUR = auto()
class ExampleType(Enum): class ExampleType(Enum):
ONE_SHOT = "one_shot" ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot" FEW_SHOT = "few_shot"

View File

@ -5,8 +5,8 @@ from test_cls_base import TestBase
class Test1(TestBase): class Test1(TestBase):
mode: str = "456" mode: str = "456"
def write(self): def write(self):
self.test_values.append("x") self.test_values.append("x")
self.test_values.append("y") self.test_values.append("y")
self.test_values.append("g") self.test_values.append("g")

View File

@ -3,9 +3,11 @@ from pydantic import BaseModel
from test_cls_base import TestBase from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase): class Test2(TestBase):
test_2_values: List = [] test_2_values: List = []
mode: str = "789" mode: str = "789"
def write(self): def write(self):
self.test_values.append(1) self.test_values.append(1)
self.test_values.append(2) self.test_values.append(2)

View File

@ -39,7 +39,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch() return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config) return get_knowledge_embedding(
self.knowledge_type, self.knowledge_source, self.vector_store_config
)
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(
@ -56,3 +58,9 @@ class KnowledgeEmbedding:
CFG.VECTOR_STORE_TYPE, self.vector_store_config CFG.VECTOR_STORE_TYPE, self.vector_store_config
) )
return vector_client.vector_name_exists() return vector_client.vector_name_exists()
def delete_by_ids(self, ids):
vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
vector_client.delete_by_ids(ids=ids)

View File

@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
def clear(self) -> None: def clear(self) -> None:
"""Clear session memory from the local file""" """Clear session memory from the local file"""
def conv_list(self, user_name: str = None) -> None: def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list""" """get user's conversation list"""
pass pass

View File

@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory): class MemHistoryMemory(BaseChatHistoryMemory):
histroies_map = FixedSizeDict(100) histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str): def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id self.chat_seesion_id = chat_session_id
self.histroies_map.update({chat_session_id: []}) self.histroies_map.update({chat_session_id: []})

View File

@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict from typing import Any, Dict
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class Cache(ABC):
class Cache(ABC):
def create(self, key: str) -> bool: def create(self, key: str) -> bool:
pass pass

View File

@ -3,15 +3,15 @@ import diskcache
import platformdirs import platformdirs
from pilot.model.cache import Cache from pilot.model.cache import Cache
class DiskCache(Cache): class DiskCache(Cache):
"""DiskCache is a cache that uses diskcache lib. """DiskCache is a cache that uses diskcache lib.
https://github.com/grantjenks/python-diskcache https://github.com/grantjenks/python-diskcache
""" """
def __init__(self, llm_name: str): def __init__(self, llm_name: str):
self._diskcache = diskcache.Cache( self._diskcache = diskcache.Cache(
os.path.join( os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
)
) )
def __getitem__(self, key: str) -> str: def __getitem__(self, key: str) -> str:

View File

@ -9,6 +9,7 @@ try:
except ImportError: except ImportError:
pass pass
class GPTCache(Cache): class GPTCache(Cache):
""" """
@ -24,7 +25,7 @@ class GPTCache(Cache):
data_dir=os.path.join( data_dir=os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache" platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
), ),
cache_obj=_cache cache_obj=_cache,
) )
else: else:
_cache = cache _cache = cache

View File

@ -1,8 +1,8 @@
from typing import Dict, Any from typing import Dict, Any
from pilot.model.cache import Cache from pilot.model.cache import Cache
class InMemoryCache(Cache):
class InMemoryCache(Cache):
def __init__(self) -> None: def __init__(self) -> None:
"Initialize that stores things in memory." "Initialize that stores things in memory."
self._cache: Dict[str, Any] = {} self._cache: Dict[str, Any] = {}
@ -21,4 +21,3 @@ class InMemoryCache(Cache):
def __contains__(self, key: str) -> bool: def __contains__(self, key: str) -> bool:
return self._cache.get(key, None) is not None return self._cache.get(key, None) is not None

View File

@ -2,7 +2,7 @@ import uuid
import json import json
import asyncio import asyncio
import time import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -16,12 +16,13 @@ from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageV
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR) from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage) from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
@ -42,19 +43,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
def __get_conv_user_message(conversations: dict): def __get_conv_user_message(conversations: dict):
messages = conversations['messages'] messages = conversations["messages"]
for item in messages: for item in messages:
if item['type'] == "human": if item["type"] == "human":
return item['data']['content'] return item["data"]["content"]
return "" return ""
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo]) @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None): async def dialogue_list(response: Response, user_id: str = None):
# 设置CORS头部信息 # 设置CORS头部信息
response.headers['Access-Control-Allow-Origin'] = '*' response.headers["Access-Control-Allow-Origin"] = "*"
response.headers['Access-Control-Allow-Methods'] = 'GET' response.headers["Access-Control-Allow-Methods"] = "GET"
response.headers['Access-Control-Request-Headers'] = 'content-type' response.headers["Access-Control-Request-Headers"] = "content-type"
dialogues: List = [] dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id) datas = DuckdbHistoryMemory.conv_list(user_id)
@ -65,26 +66,44 @@ async def dialogue_list(response: Response, user_id: str = None):
conversations = json.loads(messages) conversations = json.loads(messages)
first_conv: OnceConversation = conversations[0] first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv), conv_vo: ConversationVo = ConversationVo(
chat_mode=first_conv['chat_mode']) conv_uid=conv_uid,
user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv["chat_mode"],
)
dialogues.append(conv_vo) dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues) return Result[ConversationVo].succ(dialogues)
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes(): async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = [] scene_vos: List[ChatSceneVo] = []
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] new_modes: List[ChatScene] = [
ChatScene.ChatDb,
ChatScene.ChatData,
ChatScene.ChatDashboard,
ChatScene.ChatKnowledge,
ChatScene.ChatExecution,
]
for scene in new_modes: for scene in new_modes:
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: if not scene.value in [
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") ChatScene.ChatNormal.value,
ChatScene.InnerChatDBSummary.value,
]:
scene_vo = ChatSceneVo(
chat_scene=scene.value,
scene_name=scene.name,
param_title="Selection Param",
)
scene_vos.append(scene_vo) scene_vos.append(scene_vo)
return Result.succ(scene_vos) return Result.succ(scene_vos)
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo]) @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None): async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1() unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@ -106,11 +125,16 @@ def plugins_select_info():
def knowledge_list(): def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest() request = KnowledgeSpaceRequest()
return knowledge_service.get_knowledge_space(request) spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
@router.post('/v1/chat/mode/params/list', response_model=Result[dict]) @router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value): async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatWithDbQA.value == chat_mode: if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list()) return Result.succ(get_db_list())
@ -126,14 +150,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
return Result.succ(None) return Result.succ(None)
@router.post('/v1/chat/dialogue/delete') @router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str): async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_mem.delete() history_mem.delete()
return Result.succ(None) return Result.succ(None)
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo]) @router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str): async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}") print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = [] message_vos: List[MessageVo] = []
@ -142,12 +166,14 @@ async def dialogue_history_messages(con_uid: str):
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
for once in history_messages: for once in history_messages:
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']] once_message_vos = [
message2Vo(element, once["chat_order"]) for element in once["messages"]
]
message_vos.extend(once_message_vos) message_vos.extend(once_message_vos)
return Result.succ(message_vos) return Result.succ(message_vos)
@router.post('/v1/chat/completions') @router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()): async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
global model_semaphore, global_counter global model_semaphore, global_counter
@ -157,7 +183,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
await model_semaphore.acquire() await model_semaphore.acquire()
if not ChatScene.is_valid_mode(dialogue.chat_mode): if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")) raise StopAsyncIteration(
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
)
chat_param = { chat_param = {
"chat_session_id": dialogue.conv_uid, "chat_session_id": dialogue.conv_uid,

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any from typing import TypeVar, Union, List, Generic, Any
T = TypeVar('T') T = TypeVar("T")
class Result(Generic[T], BaseModel): class Result(Generic[T], BaseModel):
@ -28,10 +28,12 @@ class ChatSceneVo(BaseModel):
scene_name: str = Field(..., description="chat_scene name show for user") scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title") param_title: str = Field(..., description="chat_scene required parameter title")
class ConversationVo(BaseModel): class ConversationVo(BaseModel):
""" """
dialogue_uid dialogue_uid
""" """
conv_uid: str = Field(..., description="dialogue uid") conv_uid: str = Field(..., description="dialogue uid")
""" """
user input user input
@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
select_param: str = None select_param: str = None
class MessageVo(BaseModel): class MessageVo(BaseModel):
""" """
role that sends out the current message role that sends out the current message
""" """
role: str role: str
""" """
current message current message
@ -70,4 +72,3 @@ class MessageVo(BaseModel):
time the current message was sent time the current message was sent
""" """
time_stamp: Any = None time_stamp: Any = None

View File

@ -10,8 +10,10 @@ from pilot.configs.config import Config
CFG = Config() CFG = Config()
Base = declarative_base() Base = declarative_base()
class DocumentChunkEntity(Base): class DocumentChunkEntity(Base):
__tablename__ = 'document_chunk' __tablename__ = "document_chunk"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
document_id = Column(Integer) document_id = Column(Integer)
doc_name = Column(String(100)) doc_name = Column(String(100))
@ -29,8 +31,9 @@ class DocumentChunkDao:
def __init__(self): def __init__(self):
database = "knowledge_management" database = "knowledge_management"
self.db_engine = create_engine( self.db_engine = create_engine(
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True) echo=True,
)
self.Session = sessionmaker(bind=self.db_engine) self.Session = sessionmaker(bind=self.db_engine)
def create_documents_chunks(self, documents: List): def create_documents_chunks(self, documents: List):
@ -43,9 +46,10 @@ class DocumentChunkDao:
content=document.content or "", content=document.content or "",
meta_info=document.meta_info or "", meta_info=document.meta_info or "",
gmt_created=datetime.now(), gmt_created=datetime.now(),
gmt_modified=datetime.now() gmt_modified=datetime.now(),
) )
for document in documents] for document in documents
]
session.add_all(docs) session.add_all(docs)
session.commit() session.commit()
session.close() session.close()
@ -56,16 +60,26 @@ class DocumentChunkDao:
if query.id is not None: if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id) document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
if query.document_id is not None: if query.document_id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id) document_chunks = document_chunks.filter(
DocumentChunkEntity.document_id == query.document_id
)
if query.doc_type is not None: if query.doc_type is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type) document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_type == query.doc_type
)
if query.doc_name is not None: if query.doc_name is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name) document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_name == query.doc_name
)
if query.meta_info is not None: if query.meta_info is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info) document_chunks = document_chunks.filter(
DocumentChunkEntity.meta_info == query.meta_info
)
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc()) document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size) document_chunks = document_chunks.offset((page - 1) * page_size).limit(
page_size
)
result = document_chunks.all() result = document_chunks.all()
return result return result

View File

@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import ( from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest, KnowledgeQueryRequest,
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, KnowledgeQueryResponse,
KnowledgeDocumentRequest,
DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
) )
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest): def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}") print(f"/document/list params: {space_name}, {query_request}")
try: try:
return Result.succ(knowledge_space_service.get_knowledge_documents( return Result.succ(
space_name, knowledge_space_service.get_knowledge_documents(space_name, query_request)
query_request )
))
except Exception as e: except Exception as e:
return Result.faild(code="E000X", msg=f"document list error {e}") return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload") @router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)): async def document_sync(space_name: str, file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}") print(f"/document/upload params: {space_name}")
try: try:
with NamedTemporaryFile(delete=False) as tmp: with NamedTemporaryFile(delete=False) as tmp:
@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
knowledge_space_service.sync_knowledge_document( knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids space_name=space_name, doc_ids=request.doc_ids
) )
Result.succ([]) return Result.succ([])
except Exception as e: except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}") return Result.faild(code="E000X", msg=f"document sync error {e}")
@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest): def document_list(space_name: str, query_request: ChunkQueryRequest):
print(f"/document/list params: {space_name}, {query_request}") print(f"/document/list params: {space_name}, {query_request}")
try: try:
return Result.succ(knowledge_space_service.get_document_chunks( return Result.succ(knowledge_space_service.get_document_chunks(query_request))
query_request
))
except Exception as e: except Exception as e:
return Result.faild(code="E000X", msg=f"document chunk list error {e}") return Result.faild(code="E000X", msg=f"document chunk list error {e}")

View File

@ -12,7 +12,7 @@ Base = declarative_base()
class KnowledgeDocumentEntity(Base): class KnowledgeDocumentEntity(Base):
__tablename__ = 'knowledge_document' __tablename__ = "knowledge_document"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
doc_name = Column(String(100)) doc_name = Column(String(100))
doc_type = Column(String(100)) doc_type = Column(String(100))
@ -21,20 +21,22 @@ class KnowledgeDocumentEntity(Base):
status = Column(String(100)) status = Column(String(100))
last_sync = Column(String(100)) last_sync = Column(String(100))
content = Column(Text) content = Column(Text)
result = Column(Text)
vector_ids = Column(Text) vector_ids = Column(Text)
gmt_created = Column(DateTime) gmt_created = Column(DateTime)
gmt_modified = Column(DateTime) gmt_modified = Column(DateTime)
def __repr__(self): def __repr__(self):
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao: class KnowledgeDocumentDao:
def __init__(self): def __init__(self):
database = "knowledge_management" database = "knowledge_management"
self.db_engine = create_engine( self.db_engine = create_engine(
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True) echo=True,
)
self.Session = sessionmaker(bind=self.db_engine) self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_document(self, document: KnowledgeDocumentEntity): def create_knowledge_document(self, document: KnowledgeDocumentEntity):
@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
status=document.status, status=document.status,
last_sync=document.last_sync, last_sync=document.last_sync,
content=document.content or "", content=document.content or "",
result=document.result or "",
vector_ids=document.vector_ids, vector_ids=document.vector_ids,
gmt_created=datetime.now(), gmt_created=datetime.now(),
gmt_modified=datetime.now() gmt_modified=datetime.now(),
) )
session.add(knowledge_document) session.add(knowledge_document)
session.commit() session.commit()
@ -60,18 +63,32 @@ class KnowledgeDocumentDao:
session = self.Session() session = self.Session()
knowledge_documents = session.query(KnowledgeDocumentEntity) knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None: if query.id is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id) knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id == query.id
)
if query.doc_name is not None: if query.doc_name is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name) knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_name == query.doc_name
)
if query.doc_type is not None: if query.doc_type is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type) knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_type == query.doc_type
)
if query.space is not None: if query.space is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space) knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.space == query.space
)
if query.status is not None: if query.status is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status) knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.status == query.status
)
knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc()) knowledge_documents = knowledge_documents.order_by(
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size) KnowledgeDocumentEntity.id.desc()
)
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
page_size
)
result = knowledge_documents.all() result = knowledge_documents.all()
return result return result

View File

@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger from pilot.logs import logger
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao from pilot.openapi.knowledge.document_chunk_dao import (
DocumentChunkEntity,
DocumentChunkDao,
)
from pilot.openapi.knowledge.knowledge_document_dao import ( from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao, KnowledgeDocumentDao,
KnowledgeDocumentEntity, KnowledgeDocumentEntity,
) )
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity from pilot.openapi.knowledge.knowledge_space_dao import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
from pilot.openapi.knowledge.request.knowledge_request import ( from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest, KnowledgeSpaceRequest,
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, KnowledgeDocumentRequest,
DocumentQueryRequest,
ChunkQueryRequest,
) )
from enum import Enum from enum import Enum
@ -53,10 +60,7 @@ class KnowledgeService:
"""create knowledge document""" """create knowledge document"""
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest): def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
query = KnowledgeDocumentEntity( query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
doc_name=request.doc_name,
space=space
)
documents = knowledge_document_dao.get_knowledge_documents(query) documents = knowledge_document_dao.get_knowledge_documents(query)
if len(documents) > 0: if len(documents) > 0:
raise Exception(f"document name:{request.doc_name} have already named") raise Exception(f"document name:{request.doc_name} have already named")
@ -76,9 +80,7 @@ class KnowledgeService:
def get_knowledge_space(self, request: KnowledgeSpaceRequest): def get_knowledge_space(self, request: KnowledgeSpaceRequest):
query = KnowledgeSpaceEntity( query = KnowledgeSpaceEntity(
name=request.name, name=request.name, vector_type=request.vector_type, owner=request.owner
vector_type=request.vector_type,
owner=request.owner
) )
return knowledge_space_dao.get_knowledge_space(query) return knowledge_space_dao.get_knowledge_space(query)
@ -91,9 +93,12 @@ class KnowledgeService:
space=space, space=space,
status=request.status, status=request.status,
) )
return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size) return knowledge_document_dao.get_knowledge_documents(
query, page=request.page, page_size=request.page_size
)
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids): def sync_knowledge_document(self, space_name, doc_ids):
for doc_id in doc_ids: for doc_id in doc_ids:
query = KnowledgeDocumentEntity( query = KnowledgeDocumentEntity(
@ -101,12 +106,14 @@ class KnowledgeService:
space=space_name, space=space_name,
) )
doc = knowledge_document_dao.get_knowledge_documents(query)[0] doc = knowledge_document_dao.get_knowledge_documents(query)[0]
client = KnowledgeEmbedding(knowledge_source=doc.content, client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(), knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={ vector_store_config={
"vector_store_name": space_name, "vector_store_name": space_name,
}) },
)
chunk_docs = client.read() chunk_docs = client.read()
# update document status # update document status
doc.status = SyncStatus.RUNNING.name doc.status = SyncStatus.RUNNING.name
@ -114,8 +121,11 @@ class KnowledgeService:
doc.gmt_modified = datetime.now() doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc) knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings # async doc embeddings
thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc)) thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start() thread.start()
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details # save chunk details
chunk_entities = [ chunk_entities = [
DocumentChunkEntity( DocumentChunkEntity(
@ -125,9 +135,10 @@ class KnowledgeService:
content=chunk_doc.page_content, content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata), meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(), gmt_created=datetime.now(),
gmt_modified=datetime.now() gmt_modified=datetime.now(),
) )
for chunk_doc in chunk_docs] for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities) document_chunk_dao.create_documents_chunks(chunk_entities)
return True return True
@ -145,26 +156,30 @@ class KnowledgeService:
return knowledge_space_dao.delete_knowledge_space(space_id) return knowledge_space_dao.delete_knowledge_space(space_id)
"""get document chunks""" """get document chunks"""
def get_document_chunks(self, request: ChunkQueryRequest): def get_document_chunks(self, request: ChunkQueryRequest):
query = DocumentChunkEntity( query = DocumentChunkEntity(
id=request.id, id=request.id,
document_id=request.document_id, document_id=request.document_id,
doc_name=request.doc_name, doc_name=request.doc_name,
doc_type=request.doc_type doc_type=request.doc_type,
)
return document_chunk_dao.get_document_chunks(
query, page=request.page, page_size=request.page_size
) )
return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)
def async_doc_embedding(self, client, chunk_docs, doc): def async_doc_embedding(self, client, chunk_docs, doc):
logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}") logger.info(
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try: try:
vector_ids = client.knowledge_embedding_batch(chunk_docs) vector_ids = client.knowledge_embedding_batch(chunk_docs)
doc.status = SyncStatus.FINISHED.name doc.status = SyncStatus.FINISHED.name
doc.content = "embedding success" doc.result = "document embedding success"
doc.vector_ids = ",".join(vector_ids) doc.vector_ids = ",".join(vector_ids)
logger.info(f"async document embedding, success:{doc.doc_name}")
except Exception as e: except Exception as e:
doc.status = SyncStatus.FAILED.name doc.status = SyncStatus.FAILED.name
doc.content = str(e) doc.result = "document embedding failed" + str(e)
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc) return knowledge_document_dao.update_knowledge_document(doc)

View File

@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
CFG = Config() CFG = Config()
Base = declarative_base() Base = declarative_base()
class KnowledgeSpaceEntity(Base): class KnowledgeSpaceEntity(Base):
__tablename__ = 'knowledge_space' __tablename__ = "knowledge_space"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String(100)) name = Column(String(100))
vector_type = Column(String(100)) vector_type = Column(String(100))
@ -27,7 +29,10 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao: class KnowledgeSpaceDao:
def __init__(self): def __init__(self):
database = "knowledge_management" database = "knowledge_management"
self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True) self.db_engine = create_engine(
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True,
)
self.Session = sessionmaker(bind=self.db_engine) self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_space(self, space: KnowledgeSpaceRequest): def create_knowledge_space(self, space: KnowledgeSpaceRequest):
@ -38,7 +43,7 @@ class KnowledgeSpaceDao:
desc=space.desc, desc=space.desc,
owner=space.owner, owner=space.owner,
gmt_created=datetime.now(), gmt_created=datetime.now(),
gmt_modified=datetime.now() gmt_modified=datetime.now(),
) )
session.add(knowledge_space) session.add(knowledge_space)
session.commit() session.commit()
@ -49,28 +54,46 @@ class KnowledgeSpaceDao:
session = self.Session() session = self.Session()
knowledge_spaces = session.query(KnowledgeSpaceEntity) knowledge_spaces = session.query(KnowledgeSpaceEntity)
if query.id is not None: if query.id is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.id == query.id
)
if query.name is not None: if query.name is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.name == query.name
)
if query.vector_type is not None: if query.vector_type is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.vector_type == query.vector_type
)
if query.desc is not None: if query.desc is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.desc == query.desc
)
if query.owner is not None: if query.owner is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.owner == query.owner
)
if query.gmt_created is not None: if query.gmt_created is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.gmt_created == query.gmt_created
)
if query.gmt_modified is not None: if query.gmt_modified is not None:
knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified) knowledge_spaces = knowledge_spaces.filter(
KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
)
knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc()) knowledge_spaces = knowledge_spaces.order_by(
KnowledgeSpaceEntity.gmt_created.desc()
)
result = knowledge_spaces.all() result = knowledge_spaces.all()
return result return result
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity): def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
cursor = self.conn.cursor() cursor = self.conn.cursor()
query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s" query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id)) cursor.execute(
query, (space.name, space.vector_type, space.desc, space.owner, space_id)
)
self.conn.commit() self.conn.commit()
cursor.close() cursor.close()

View File

@ -30,12 +30,14 @@ class KnowledgeDocumentRequest(BaseModel):
"""doc_type: doc type""" """doc_type: doc type"""
doc_type: str doc_type: str
"""content: content""" """content: content"""
content: str content: str = None
"""text_chunk_size: text_chunk_size""" """text_chunk_size: text_chunk_size"""
# text_chunk_size: int # text_chunk_size: int
class DocumentQueryRequest(BaseModel): class DocumentQueryRequest(BaseModel):
"""doc_name: doc path""" """doc_name: doc path"""
doc_name: str = None doc_name: str = None
"""doc_type: doc type""" """doc_type: doc type"""
doc_type: str = None doc_type: str = None
@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel): class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids""" """doc_ids: doc ids"""
doc_ids: List doc_ids: List
class ChunkQueryRequest(BaseModel): class ChunkQueryRequest(BaseModel):
"""id: id""" """id: id"""
id: int = None id: int = None
"""document_id: doc id""" """document_id: doc id"""
document_id: int = None document_id: int = None

View File

@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC): class ExampleSelector(BaseModel, ABC):
examples: List[List] examples: List[List]
use_example: bool = False use_example: bool = False

View File

@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector from pilot.prompts.example_base import ExampleSelector
def jinja2_formatter(template: str, **kwargs: Any) -> str: def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2.""" """Format a template using jinja2."""
try: try:

View File

@ -1,11 +1,13 @@
from enum import Enum from enum import Enum
class Scene: class Scene:
def __init__(self, code, describe, is_inner): def __init__(self, code, describe, is_inner):
self.code = code self.code = code
self.describe = describe self.describe = describe
self.is_inner = is_inner self.is_inner = is_inner
class ChatScene(Enum): class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute" ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa" ChatWithDbQA = "chat_with_db_qa"
@ -24,4 +26,3 @@ class ChatScene(Enum):
@staticmethod @staticmethod
def is_valid_mode(mode): def is_valid_mode(mode):
return any(mode == item.value for item in ChatScene) return any(mode == item.value for item in ChatScene)

View File

@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
generate_htm_table, generate_htm_table,
) )
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
)
CFG = Config() CFG = Config()
@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
report_name: str report_name: str
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input, report_name):
self, chat_session_id, db_name, user_input, report_name
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatWithDbExecute,
@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
# TODO 修复流程 # TODO 修复流程
print(str(e)) print(str(e))
chart_datas.append(chart_data) chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id report_data.conv_uid = self.chat_session_id
@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
report_data.charts = chart_datas report_data.charts = chart_datas
return report_data return report_data

View File

@ -16,7 +16,3 @@ class ReportData(BaseModel):
template_name: str template_name: str
template_introduce: str template_introduce: str
charts: List[ChartData] charts: List[ChartData]

View File

@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
response = json.loads(clean_str) response = json.loads(clean_str)
chart_items = List[ChartItem] chart_items = List[ChartItem]
for item in response: for item in response:
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])) chart_items.append(
ChartItem(
item["sql"], item["title"], item["thoughts"], item["showcase"]
)
)
return chart_items return chart_items
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:

View File

@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
Ensure the response is correct json and can be parsed by Python json.loads Ensure the response is correct json and can be parsed by Python json.loads
""" """
RESPONSE_FORMAT = [{ RESPONSE_FORMAT = [
{
"sql": "data analysis SQL", "sql": "data analysis SQL",
"title": "Data Analysis Title", "title": "Data Analysis Title",
"showcase": "What type of charts to show", "showcase": "What type of charts to show",
"thoughts": "Current thinking and value of data analysis" "thoughts": "Current thinking and value of data analysis",
}] }
]
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input):
self, chat_session_id, db_name, user_input
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbExecute, chat_mode=ChatScene.ChatWithDbExecute,

View File

@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, db_name, user_input):
self, chat_session_id, db_name, user_input
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbQA, chat_mode=ChatScene.ChatWithDbQA,
@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
"table_info": table_info, "table_info": table_info,
} }
return input_values return input_values

View File

@ -34,7 +34,6 @@ RESPONSE_FORMAT = {
} }
EXAMPLE_TYPE = ExampleType.ONE_SHOT EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output ### Whether the model service is streaming output
@ -47,8 +46,10 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT, stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT), output_parser=PluginChatOutputParser(
example=example sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
example=example,
) )
CFG.prompt_templates.update({prompt.template_scene: prompt}) CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__( def __init__(self, chat_session_id, user_input, knowledge_name):
self, chat_session_id, user_input, knowledge_name
):
""" """ """ """
super().__init__( super().__init__(
chat_mode=ChatScene.ChatNewKnowledge, chat_mode=ChatScene.ChatNewKnowledge,
@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatNewKnowledge.value return ChatScene.ChatNewKnowledge.value

View File

@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
) )
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatDefaultKnowledge.value return ChatScene.ChatDefaultKnowledge.value

View File

@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
} }
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.InnerChatDBSummary.value return ChatScene.InnerChatDBSummary.value

View File

@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input} input_values = {"context": context, "question": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatUrlKnowledge.value return ChatScene.ChatUrlKnowledge.value

View File

@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
) )
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatKnowledge.value return ChatScene.ChatKnowledge.value

View File

@ -85,7 +85,6 @@ class OnceConversation:
self.messages.clear() self.messages.clear()
self.session_id = None self.session_id = None
def get_user_message(self): def get_user_message(self):
for once in self.messages: for once in self.messages:
if isinstance(once, HumanMessage): if isinstance(once, HumanMessage):

View File

@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True) load_dotenv(verbose=True, override=True)
del load_dotenv del load_dotenv

View File

@ -122,7 +122,7 @@ app.add_middleware(
allow_origins=origins, allow_origins=origins,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"] allow_headers=["*"],
) )

View File

@ -107,12 +107,16 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue, add_knowledge_base_dialogue,
] ]
def swagger_monkey_patch(*args, **kwargs): def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html( return get_swagger_ui_html(
*args, **kwargs, *args,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', **kwargs,
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
) )
applications.get_swagger_ui_html = swagger_monkey_patch applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI() app = FastAPI()
@ -360,14 +364,18 @@ def http_bot(
response = chat.stream_call() response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
state.messages[-1][-1] = msg state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg) chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message) chat.memory.append(chat.current_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """ state.messages[-1][
-1
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@ -693,8 +701,12 @@ def signal_handler(sig, frame):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) parser.add_argument(
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode') "--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config # old version server config
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
@ -702,27 +714,24 @@ if __name__ == "__main__":
parser.add_argument("--concurrency-count", type=int, default=10) parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true") parser.add_argument("--share", default=False, action="store_true")
# init server config # init server config
args = parser.parse_args() args = parser.parse_args()
server_init(args) server_init(args)
if args.new: if args.new:
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000) uvicorn.run(app, host="0.0.0.0", port=5000)
else: else:
### Compatibility mode starts the old version server by default ### Compatibility mode starts the old version server by default
demo = build_webdemo() demo = build_webdemo()
demo.queue( demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch( ).launch(
server_name=args.host, server_name=args.host,
server_port=args.port, server_port=args.port,
share=args.share, share=args.share,
max_threads=200, max_threads=200,
) )

View File

@ -38,7 +38,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch() return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self): def init_knowledge_embedding(self):
return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config) return get_knowledge_embedding(
self.file_type.upper(), self.file_path, self.vector_store_config
)
def similar_search(self, text, topk): def similar_search(self, text, topk):
vector_client = VectorStoreConnector( vector_client = VectorStoreConnector(