Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/memory/chat_history/duckdb_history.py
#	pilot/openapi/api_v1/api_v1.py
#	pilot/scene/base.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_execution/example.py
This commit is contained in:
tuyang.yhj 2023-06-28 11:40:22 +08:00
commit caa1a41065
46 changed files with 634 additions and 291 deletions

View File

@ -2,14 +2,60 @@
import { useRouter, useSearchParams } from 'next/navigation'
import React, { useState, useEffect } from 'react'
import { Button, Table } from '@/lib/mui'
import {
Button,
Table,
Sheet,
Modal,
Box,
Stack,
Input,
styled
} from '@/lib/mui'
import moment from 'moment'
import { message } from 'antd'
// Titles of the steps shown across the top of the "Add Datasource" modal.
// The modal renders one <Item> per entry and bolds the entry whose index
// equals `activeStep`.
const stepsOfAddingDocument = [
'Choose a Datasource type',
'Setup the Datasource'
]
// Datasource types offered in the first step of the "Add Datasource" modal.
// Each entry is rendered as a clickable card showing `title` and `subTitle`;
// `type` is used as the React key and to decide what a click does.
// NOTE(review): in the visible click handler only 'webPage' advances to the
// next step — 'text' and 'file' cards currently do nothing when clicked.
const documentTypeList = [
{
type: 'text',
title: 'Text',
subTitle: 'Paste some text'
},
{
type: 'webPage',
title: 'Web Page',
subTitle: 'Crawl text from a web page'
},
{
type: 'file',
title: 'File',
subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
}
]
// Styled variant of `Sheet` used for each step title in the modal header.
// Takes half of the available width, centers its text, and derives its
// background, typography, padding and text color from the active theme.
const Item = styled(Sheet)(({ theme }) => ({
width: '50%',
// Dark mode uses the theme's level1 background; any other mode is plain white.
backgroundColor:
theme.palette.mode === 'dark' ? theme.palette.background.level1 : '#fff',
...theme.typography.body2,
padding: theme.spacing(1),
textAlign: 'center',
borderRadius: 4,
color: theme.vars.palette.text.secondary
}))
const Documents = () => {
const router = useRouter()
const spaceName = useSearchParams().get('name')
const [isAddDocumentModalShow, setIsAddDocumentModalShow] =
useState<boolean>(false)
const [activeStep, setActiveStep] = useState<number>(0)
const [documents, setDocuments] = useState<any>([])
const [webPageUrl, setWebPageUrl] = useState<string>('')
const [documentName, setDocumentName] = useState<string>('')
useEffect(() => {
async function fetchDocuments() {
const res = await fetch(
@ -31,6 +77,19 @@ const Documents = () => {
}, [])
return (
<div className="p-4">
<Sheet
sx={{
display: 'flex',
flexDirection: 'row-reverse'
}}
>
<Button
variant="outlined"
onClick={() => setIsAddDocumentModalShow(true)}
>
+ Add Datasource
</Button>
</Sheet>
<Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
<thead>
<tr>
@ -70,9 +129,9 @@ const Documents = () => {
)
const data = await res.json()
if (data.success) {
message.success('success');
message.success('success')
} else {
message.error(data.err_msg || 'failed');
message.error(data.err_msg || 'failed')
}
}}
>
@ -95,6 +154,137 @@ const Documents = () => {
))}
</tbody>
</Table>
<Modal
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
'z-index': 1000
}}
open={isAddDocumentModalShow}
onClose={() => setIsAddDocumentModalShow(false)}
>
<Sheet
variant="outlined"
sx={{
width: 800,
borderRadius: 'md',
p: 3,
boxShadow: 'lg'
}}
>
<Box sx={{ width: '100%' }}>
<Stack spacing={2} direction="row">
{stepsOfAddingDocument.map((item: any, index: number) => (
<Item
key={item}
sx={{ fontWeight: activeStep === index ? 'bold' : '' }}
>
{item}
</Item>
))}
</Stack>
</Box>
{activeStep === 0 ? (
<>
<Box sx={{ margin: '30px auto' }}>
{documentTypeList.map((item: any) => (
<Sheet
key={item.type}
sx={{
boxSizing: 'border-box',
height: '80px',
padding: '12px',
display: 'flex',
flexDirection: 'column',
justifyContent: 'space-between',
border: '1px solid gray',
borderRadius: '6px',
marginBottom: '20px',
cursor: 'pointer'
}}
onClick={() => {
if (item.type === 'webPage') {
setActiveStep(1)
}
}}
>
<Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>
{item.title}
</Sheet>
<Sheet>{item.subTitle}</Sheet>
</Sheet>
))}
</Box>
</>
) : (
<>
<Box sx={{ margin: '30px auto' }}>
Name:
<Input
placeholder="Please input the name"
onChange={(e: any) => setDocumentName(e.target.value)}
sx={{ marginBottom: '20px' }}
/>
Web Page URL:
<Input
placeholder="Please input the Web Page URL"
onChange={(e: any) => setWebPageUrl(e.target.value)}
/>
</Box>
<Button
onClick={async () => {
if (documentName === '') {
message.error('Please input the name')
return
}
if (webPageUrl === '') {
message.error('Please input the Web Page URL')
return
}
const res = await fetch(
`http://localhost:8000/knowledge/${spaceName}/document/add`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
doc_name: documentName,
content: webPageUrl,
doc_type: 'URL'
})
}
)
const data = await res.json()
if (data.success) {
message.success('success')
setIsAddDocumentModalShow(false)
const res = await fetch(
`http://localhost:8000/knowledge/${spaceName}/document/list`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({})
}
)
const data = await res.json()
if (data.success) {
setDocuments(data.data)
}
} else {
message.error(data.err_msg || 'failed')
}
}}
>
Finish
</Button>
</>
)}
</Sheet>
</Modal>
</div>
)
}

View File

@ -1,8 +1,7 @@
'use client'
import { useRouter } from 'next/navigation'
import type { ProFormInstance } from '@ant-design/pro-components'
import React, { useState, useRef, useEffect } from 'react'
import React, { useState, useEffect } from 'react'
import { message } from 'antd'
import {
Modal,
@ -31,16 +30,33 @@ const stepsOfAddingSpace = [
'Choose a Datasource type',
'Setup the Datasource'
]
// Datasource types offered in the "Choose a Datasource type" step of the
// knowledge-space modal. Each entry is rendered as a clickable card showing
// `title` and `subTitle`; `type` is used as the React key and to decide what
// a click does.
// NOTE(review): in the visible click handler only 'webPage' advances to the
// next step — 'text' and 'file' cards currently do nothing when clicked.
const documentTypeList = [
{
type: 'text',
title: 'Text',
subTitle: 'Paste some text'
},
{
type: 'webPage',
title: 'Web Page',
subTitle: 'Crawl text from a web page'
},
{
type: 'file',
title: 'File',
subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
}
]
const Index = () => {
const router = useRouter()
const formRef = useRef<ProFormInstance>()
const [activeStep, setActiveStep] = useState<number>(0)
const [knowledgeSpaceList, setKnowledgeSpaceList] = useState<any>([])
const [isAddKnowledgeSpaceModalShow, setIsAddKnowledgeSpaceModalShow] =
useState<boolean>(false)
const [knowledgeSpaceName, setKnowledgeSpaceName] = useState<string>('')
const [webPageUrl, setWebPageUrl] = useState<string>('')
const [documentName, setDocumentName] = useState<string>('')
useEffect(() => {
async function fetchData() {
const res = await fetch('http://localhost:8000/knowledge/space/list', {
@ -59,15 +75,21 @@ const Index = () => {
}, [])
return (
<>
<div className="page-header p-4">
<div className="page-header-title">Knowledge Spaces</div>
<Sheet sx={{
display: "flex",
justifyContent: "space-between"
}} className="p-4">
<Sheet sx={{
fontSize: '30px',
fontWeight: 'bold'
}}>Knowledge Spaces</Sheet>
<Button
onClick={() => setIsAddKnowledgeSpaceModalShow(true)}
variant="outlined"
>
+ New Knowledge Space
</Button>
</div>
</Sheet>
<div className="page-body p-4">
<Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
<thead>
@ -100,7 +122,6 @@ const Index = () => {
</Table>
</div>
<Modal
title="Add Knowledge Space"
sx={{
display: 'flex',
justifyContent: 'center',
@ -191,14 +212,42 @@ const Index = () => {
) : activeStep === 1 ? (
<>
<Box sx={{ margin: '30px auto' }}>
<Button variant="outlined" onClick={() => setActiveStep(2)}>
Web Page
</Button>
{documentTypeList.map((item: any) => (
<Sheet
key={item.type}
sx={{
boxSizing: 'border-box',
height: '80px',
padding: '12px',
display: 'flex',
flexDirection: 'column',
justifyContent: 'space-between',
border: '1px solid gray',
borderRadius: '6px',
marginBottom: '20px',
cursor: 'pointer'
}}
onClick={() => {
if (item.type === 'webPage') {
setActiveStep(2);
}
}}
>
<Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>{item.title}</Sheet>
<Sheet>{item.subTitle}</Sheet>
</Sheet>
))}
</Box>
</>
) : (
<>
<Box sx={{ margin: '30px auto' }}>
Name:
<Input
placeholder="Please input the name"
onChange={(e: any) => setDocumentName(e.target.value)}
sx={{ marginBottom: '20px'}}
/>
Web Page URL:
<Input
placeholder="Please input the Web Page URL"
@ -207,6 +256,14 @@ const Index = () => {
</Box>
<Button
onClick={async () => {
if (documentName === '') {
message.error('Please input the name');
return;
}
if (webPageUrl === '') {
message.error('Please input the Web Page URL');
return;
}
const res = await fetch(
`http://localhost:8000/knowledge/${knowledgeSpaceName}/document/add`,
{
@ -215,7 +272,8 @@ const Index = () => {
'Content-Type': 'application/json'
},
body: JSON.stringify({
doc_name: webPageUrl,
doc_name: documentName,
content: webPageUrl,
doc_type: 'URL'
})
}
@ -235,24 +293,6 @@ const Index = () => {
)}
</Sheet>
</Modal>
<style jsx>{`
.page-header {
display: flex;
justify-content: space-between;
.page-header-title {
font-size: 30px;
font-weight: bold;
}
}
.datasource-type-wrap {
height: 100px;
line-height: 100px;
border: 1px solid black;
border-radius: 20px;
margin-bottom: 20px;
cursor: pointer;
}
`}</style>
</>
)
}

View File

@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
elif hasattr(obj, '__dict__'):
elif hasattr(obj, "__dict__"):
return obj.__dict__
else:
return json.JSONEncoder.default(self, obj)
return json.JSONEncoder.default(self, obj)

View File

@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
headers={'Authorization': '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'})
response = session.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime('%Y%m%d%H%M%S')
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.

View File

@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
THREE = auto()
FOUR = auto()
class ExampleType(Enum):
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"

View File

@ -92,7 +92,7 @@ class Config(metaclass=Singleton):
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

View File

@ -6,7 +6,7 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase,Test1
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2
CFG = Config()
@ -60,21 +60,21 @@ CFG = Config()
# if __name__ == "__main__":
# def __extract_json(s):
# i = s.index("{")
# count = 1 # current nesting depth, i.e. the number of '{' not yet closed
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # check that the final '}' was found
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
# def __extract_json(s):
# i = s.index("{")
# count = 1 # current nesting depth, i.e. the number of '{' not yet closed
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # check that the final '}' was found
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
if __name__ == "__main__":
test1 = Test1()
@ -83,4 +83,4 @@ if __name__ == "__main__":
test1.test()
test2.write()
test1.test()
test2.test()
test2.test()

View File

@ -4,9 +4,9 @@ from test_cls_base import TestBase
class Test1(TestBase):
mode:str = "456"
mode: str = "456"
def write(self):
self.test_values.append("x")
self.test_values.append("y")
self.test_values.append("g")

View File

@ -3,13 +3,15 @@ from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase):
test_2_values:List = []
mode:str = "789"
test_2_values: List = []
mode: str = "789"
def write(self):
self.test_values.append(1)
self.test_values.append(2)
self.test_values.append(3)
self.test_2_values.append("x")
self.test_2_values.append("y")
self.test_2_values.append("z")
self.test_2_values.append("z")

View File

@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
test_values: List = []
mode:str = "123"
mode: str = "123"
def test(self):
print(self.__class__.__name__ + ":" )
print(self.__class__.__name__ + ":")
print(self.test_values)
print(self.mode)
print(self.mode)

View File

@ -39,7 +39,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
return get_knowledge_embedding(
self.knowledge_type, self.knowledge_source, self.vector_store_config
)
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(
@ -56,3 +58,9 @@ class KnowledgeEmbedding:
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
return vector_client.vector_name_exists()
def delete_by_ids(self, ids):
"""Forward a delete-by-ids request to the configured vector store.

Builds a ``VectorStoreConnector`` for ``CFG.VECTOR_STORE_TYPE`` using this
instance's ``vector_store_config`` and calls its ``delete_by_ids`` with the
given ``ids``. The connector's return value is discarded, so this method
always returns ``None``.

Args:
    ids: passed through unchanged as the ``ids`` keyword argument.
        NOTE(review): presumably identifiers of stored vectors — confirm
        the expected type/format against the connector implementation.
"""
# A new connector is created for every call rather than being cached.
vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, self.vector_store_config
)
vector_client.delete_by_ids(ids=ids)

View File

@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
def clear(self) -> None:
"""Clear session memory from the local file"""
def conv_list(self, user_name:str=None) -> None:
def conv_list(self, user_name: str = None) -> None:
"""get user's conversation list"""
pass

View File

@ -11,7 +11,7 @@ from pilot.scene.message import (
conversation_from_dict,
conversations_to_dict,
)
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
CFG = Config()
@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
histroies_map = FixedSizeDict(100)
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.histroies_map.update({chat_session_id: []})

View File

@ -1,4 +1,4 @@
from .base import Cache
from .disk_cache import DiskCache
from .memory_cache import InMemoryCache
from .gpt_cache import GPTCache
from .gpt_cache import GPTCache

View File

@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod
class Cache(ABC):
def create(self, key: str) -> bool:
pass
@ -24,4 +24,4 @@ class Cache(ABC):
@abstractmethod
def __contains__(self, key: str) -> bool:
"""see if we can return a cached value for the passed key"""
pass
pass

View File

@ -3,15 +3,15 @@ import diskcache
import platformdirs
from pilot.model.cache import Cache
class DiskCache(Cache):
"""DiskCache is a cache that uses diskcache lib.
https://github.com/grantjenks/python-diskcache
https://github.com/grantjenks/python-diskcache
"""
def __init__(self, llm_name: str):
self._diskcache = diskcache.Cache(
os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
)
os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
)
def __getitem__(self, key: str) -> str:
@ -22,6 +22,6 @@ class DiskCache(Cache):
def __contains__(self, key: str) -> bool:
return key in self._diskcache
def clear(self):
self._diskcache.clear()
self._diskcache.clear()

View File

@ -9,22 +9,23 @@ try:
except ImportError:
pass
class GPTCache(Cache):
"""
GPTCache is a semantic cache that uses
"""
GPTCache is a semantic cache that uses
"""
def __init__(self, cache) -> None:
"""GPT Cache is a semantic cache that uses GPTCache lib."""
if isinstance(cache, str):
_cache = Cache()
init_similar_cache(
data_dir=os.path.join(
platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
),
cache_obj=_cache
cache_obj=_cache,
)
else:
_cache = cache
@ -41,4 +42,4 @@ class GPTCache(Cache):
return get(key) is not None
def create(self, llm: str, **kwargs: Dict[str, Any]) -> str:
pass
pass

View File

@ -1,24 +1,23 @@
from typing import Dict, Any
from pilot.model.cache import Cache
class InMemoryCache(Cache):
def __init__(self) -> None:
"Initialize that stores things in memory."
self._cache: Dict[str, Any] = {}
def create(self, key: str) -> bool:
pass
pass
def clear(self):
return self._cache.clear()
def __setitem__(self, key: str, value: str) -> None:
self._cache[key] = value
def __getitem__(self, key: str) -> str:
return self._cache[key]
def __contains__(self, key: str) -> bool:
return self._cache.get(key, None) is not None
def __contains__(self, key: str) -> bool:
return self._cache.get(key, None) is not None

View File

@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@ -16,12 +16,13 @@ from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageV
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)
from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@ -42,19 +43,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
def __get_conv_user_message(conversations: dict):
messages = conversations['messages']
messages = conversations["messages"]
for item in messages:
if item['type'] == "human":
return item['data']['content']
if item["type"] == "human":
return item["data"]["content"]
return ""
@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
# Set the CORS response headers
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Methods'] = 'GET'
response.headers['Access-Control-Request-Headers'] = 'content-type'
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET"
response.headers["Access-Control-Request-Headers"] = "content-type"
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
@ -65,26 +66,44 @@ async def dialogue_list(response: Response, user_id: str = None):
conversations = json.loads(messages)
first_conv: OnceConversation = conversations[0]
conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv['chat_mode'])
conv_vo: ConversationVo = ConversationVo(
conv_uid=conv_uid,
user_input=__get_conv_user_message(first_conv),
chat_mode=first_conv["chat_mode"],
)
dialogues.append(conv_vo)
return Result[ConversationVo].succ(dialogues)
@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
scene_vos: List[ChatSceneVo] = []
new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
new_modes: List[ChatScene] = [
ChatScene.ChatDb,
ChatScene.ChatData,
ChatScene.ChatDashboard,
ChatScene.ChatKnowledge,
ChatScene.ChatExecution,
]
for scene in new_modes:
if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
if not scene.value in [
ChatScene.ChatNormal.value,
ChatScene.InnerChatDBSummary.value,
]:
scene_vo = ChatSceneVo(
chat_scene=scene.value,
scene_name=scene.name,
param_title="Selection Param",
)
scene_vos.append(scene_vo)
return Result.succ(scene_vos)
@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
unique_id = uuid.uuid1()
return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@ -92,7 +111,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
params:dict = {}
params: dict = {}
for name in dbs:
params.update({name: name})
return params
@ -106,11 +125,16 @@ def plugins_select_info():
def knowledge_list():
"""return knowledge space list"""
params: dict = {}
request = KnowledgeSpaceRequest()
return knowledge_service.get_knowledge_space(request)
spaces = knowledge_service.get_knowledge_space(request)
for space in spaces:
params.update({space.name: space.name})
return params
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
if ChatScene.ChatWithDbQA.value == chat_mode:
return Result.succ(get_db_list())
@ -126,14 +150,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
return Result.succ(None)
@router.post('/v1/chat/dialogue/delete')
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_mem = DuckdbHistoryMemory(con_uid)
history_mem.delete()
return Result.succ(None)
@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
print(f"dialogue_history_messages:{con_uid}")
message_vos: List[MessageVo] = []
@ -142,12 +166,14 @@ async def dialogue_history_messages(con_uid: str):
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
once_message_vos = [
message2Vo(element, once["chat_order"]) for element in once["messages"]
]
message_vos.extend(once_message_vos)
return Result.succ(message_vos)
@router.post('/v1/chat/completions')
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
global model_semaphore, global_counter
@ -157,7 +183,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
await model_semaphore.acquire()
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
raise StopAsyncIteration(
Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
)
chat_param = {
"chat_session_id": dialogue.conv_uid,

View File

@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
T = TypeVar('T')
T = TypeVar("T")
class Result(Generic[T], BaseModel):
@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel):
class ChatSceneVo(BaseModel):
chat_scene: str = Field(..., description="chat_scene")
scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title")
chat_scene: str = Field(..., description="chat_scene")
scene_name: str = Field(..., description="chat_scene name show for user")
param_title: str = Field(..., description="chat_scene required parameter title")
class ConversationVo(BaseModel):
"""
dialogue_uid
"""
conv_uid: str = Field(..., description="dialogue uid")
conv_uid: str = Field(..., description="dialogue uid")
"""
user input
"""
@ -44,7 +46,7 @@ class ConversationVo(BaseModel):
"""
the scene of chat
"""
chat_mode: str = Field(..., description="the scene of chat ")
chat_mode: str = Field(..., description="the scene of chat ")
"""
chat scene select param
@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
select_param: str = None
class MessageVo(BaseModel):
"""
role that sends out the current message
role that sends out the current message
"""
role: str
"""
current message
@ -70,4 +72,3 @@ class MessageVo(BaseModel):
time the current message was sent
"""
time_stamp: Any = None

View File

@ -10,8 +10,10 @@ from pilot.configs.config import Config
CFG = Config()
Base = declarative_base()
class DocumentChunkEntity(Base):
__tablename__ = 'document_chunk'
__tablename__ = "document_chunk"
id = Column(Integer, primary_key=True)
document_id = Column(Integer)
doc_name = Column(String(100))
@ -29,43 +31,55 @@ class DocumentChunkDao:
def __init__(self):
database = "knowledge_management"
self.db_engine = create_engine(
f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
echo=True)
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
echo=True,
)
self.Session = sessionmaker(bind=self.db_engine)
def create_documents_chunks(self, documents:List):
def create_documents_chunks(self, documents: List):
session = self.Session()
docs = [
DocumentChunkEntity(
doc_name=document.doc_name,
doc_type=document.doc_type,
document_id=document.document_id,
content=document.content or "",
meta_info=document.meta_info or "",
gmt_created=datetime.now(),
gmt_modified=datetime.now()
doc_name=document.doc_name,
doc_type=document.doc_type,
document_id=document.document_id,
content=document.content or "",
meta_info=document.meta_info or "",
gmt_created=datetime.now(),
gmt_modified=datetime.now(),
)
for document in documents]
for document in documents
]
session.add_all(docs)
session.commit()
session.close()
def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20):
def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
session = self.Session()
document_chunks = session.query(DocumentChunkEntity)
if query.id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
if query.document_id is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
document_chunks = document_chunks.filter(
DocumentChunkEntity.document_id == query.document_id
)
if query.doc_type is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_type == query.doc_type
)
if query.doc_name is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
document_chunks = document_chunks.filter(
DocumentChunkEntity.doc_name == query.doc_name
)
if query.meta_info is not None:
document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
document_chunks = document_chunks.filter(
DocumentChunkEntity.meta_info == query.meta_info
)
document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
document_chunks = document_chunks.offset((page - 1) * page_size).limit(
page_size
)
result = document_chunks.all()
return result

View File

@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeQueryRequest,
KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
KnowledgeQueryResponse,
KnowledgeDocumentRequest,
DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
)
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest):
print(f"/document/list params: {space_name}, {query_request}")
try:
return Result.succ(knowledge_space_service.get_knowledge_documents(
space_name,
query_request
))
return Result.succ(
knowledge_space_service.get_knowledge_documents(space_name, query_request)
)
except Exception as e:
return Result.faild(code="E000X", msg=f"document list error {e}")
@router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)):
async def document_sync(space_name: str, file: UploadFile = File(...)):
print(f"/document/upload params: {space_name}")
try:
with NamedTemporaryFile(delete=False) as tmp:
@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids
)
Result.succ([])
return Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document sync error {e}")
@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
    """List document chunks matching the query.

    Args:
        space_name: Name of the knowledge space; only echoed in the debug
            print, the lookup itself is driven by ``query_request``.
        query_request: Chunk filter and paging options.

    Returns:
        ``Result.succ`` wrapping the matching chunks, or ``Result.faild``
        with code ``E000X`` when the lookup raises.
    """
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        chunks = knowledge_space_service.get_document_chunks(query_request)
        return Result.succ(chunks)
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document chunk list error {e}")

View File

@ -12,7 +12,7 @@ Base = declarative_base()
class KnowledgeDocumentEntity(Base):
__tablename__ = 'knowledge_document'
__tablename__ = "knowledge_document"
id = Column(Integer, primary_key=True)
doc_name = Column(String(100))
doc_type = Column(String(100))
@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
status = Column(String(100))
last_sync = Column(String(100))
content = Column(Text)
result = Column(Text)
vector_ids = Column(Text)
gmt_created = Column(DateTime)
gmt_modified = Column(DateTime)
def __repr__(self):
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class KnowledgeDocumentDao:
def __init__(self):
    """Create the MySQL engine and session factory used by this DAO."""
    database = "knowledge_management"
    # Connection settings come from the shared project configuration.
    db_url = f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}"
    self.db_engine = create_engine(db_url, echo=True)
    self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_document(self, document:KnowledgeDocumentEntity):
def create_knowledge_document(self, document: KnowledgeDocumentEntity):
session = self.Session()
knowledge_document = KnowledgeDocumentEntity(
doc_name=document.doc_name,
@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
status=document.status,
last_sync=document.last_sync,
content=document.content or "",
result=document.result or "",
vector_ids=document.vector_ids,
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
session.add(knowledge_document)
session.commit()
@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
session = self.Session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.id == query.id
)
if query.doc_name is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_name == query.doc_name
)
if query.doc_type is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.doc_type == query.doc_type
)
if query.space is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.space == query.space
)
if query.status is not None:
knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
knowledge_documents = knowledge_documents.filter(
KnowledgeDocumentEntity.status == query.status
)
knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
knowledge_documents = knowledge_documents.order_by(
KnowledgeDocumentEntity.id.desc()
)
knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
page_size
)
result = knowledge_documents.all()
return result
def update_knowledge_document(self, document: KnowledgeDocumentEntity):
    """Merge ``document`` into the database and return its primary key.

    Args:
        document: Entity whose state should be persisted; it may have been
            loaded by a different session.

    Returns:
        The ``id`` of the merged (persistent) entity.

    Raises:
        Exception: Whatever the merge/commit raises, after the transaction
            has been rolled back.
    """
    session = self.Session()
    try:
        merged = session.merge(document)
        session.commit()
        # Read the id while the session is still open: by default commit()
        # expires attributes, and reloading them needs a live session.
        return merged.id
    except Exception:
        session.rollback()
        raise
    finally:
        # Always hand the connection back to the pool.
        session.close()
def delete_knowledge_document(self, document_id:int):
def delete_knowledge_document(self, document_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_document WHERE id = %s"
cursor.execute(query, (document_id,))

View File

@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.openapi.knowledge.document_chunk_dao import (
DocumentChunkEntity,
DocumentChunkDao,
)
from pilot.openapi.knowledge.knowledge_document_dao import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.openapi.knowledge.knowledge_space_dao import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
from pilot.openapi.knowledge.request.knowledge_request import (
KnowledgeSpaceRequest,
KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
KnowledgeDocumentRequest,
DocumentQueryRequest,
ChunkQueryRequest,
)
from enum import Enum
@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()
CFG=Config()
CFG = Config()
class SyncStatus(Enum):
@ -53,10 +60,7 @@ class KnowledgeService:
"""create knowledge document"""
def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
query = KnowledgeDocumentEntity(
doc_name=request.doc_name,
space=space
)
query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
documents = knowledge_document_dao.get_knowledge_documents(query)
if len(documents) > 0:
raise Exception(f"document name:{request.doc_name} have already named")
@ -74,26 +78,27 @@ class KnowledgeService:
"""get knowledge space"""
def get_knowledge_space(self, request: KnowledgeSpaceRequest):
    """Return the knowledge spaces matching the request.

    The request's ``name``, ``vector_type`` and ``owner`` are copied into
    a ``KnowledgeSpaceEntity`` that the DAO uses as its filter.
    """
    criteria = {
        "name": request.name,
        "vector_type": request.vector_type,
        "owner": request.owner,
    }
    space_filter = KnowledgeSpaceEntity(**criteria)
    return knowledge_space_dao.get_knowledge_space(space_filter)
"""get knowledge get_knowledge_documents"""
def get_knowledge_documents(self, space, request: DocumentQueryRequest):
    """Return one page of documents belonging to ``space``.

    Args:
        space: Knowledge space the documents must belong to.
        request: Supplies the optional ``doc_name`` / ``doc_type`` /
            ``status`` filters plus ``page`` and ``page_size``.
    """
    criteria = {
        "doc_name": request.doc_name,
        "doc_type": request.doc_type,
        "space": space,
        "status": request.status,
    }
    document_filter = KnowledgeDocumentEntity(**criteria)
    return knowledge_document_dao.get_knowledge_documents(
        document_filter, page=request.page, page_size=request.page_size
    )
"""sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids):
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
@ -101,12 +106,14 @@ class KnowledgeService:
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
client = KnowledgeEmbedding(knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
})
client = KnowledgeEmbedding(
knowledge_source=doc.content,
knowledge_type=doc.doc_type.upper(),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
vector_store_config={
"vector_store_name": space_name,
},
)
chunk_docs = client.read()
# update document status
doc.status = SyncStatus.RUNNING.name
@ -114,9 +121,12 @@ class KnowledgeService:
doc.gmt_modified = datetime.now()
knowledge_document_dao.update_knowledge_document(doc)
# async doc embeddings
thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
thread = threading.Thread(
target=self.async_doc_embedding, args=(client, chunk_docs, doc)
)
thread.start()
#save chunk details
logger.info(f"begin save document chunks, doc:{doc.doc_name}")
# save chunk details
chunk_entities = [
DocumentChunkEntity(
doc_name=doc.doc_name,
@ -125,9 +135,10 @@ class KnowledgeService:
content=chunk_doc.page_content,
meta_info=str(chunk_doc.metadata),
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
for chunk_doc in chunk_docs]
for chunk_doc in chunk_docs
]
document_chunk_dao.create_documents_chunks(chunk_entities)
return True
@ -145,26 +156,30 @@ class KnowledgeService:
return knowledge_space_dao.delete_knowledge_space(space_id)
"""get document chunks"""
def get_document_chunks(self, request: ChunkQueryRequest):
    """Return one page of document chunks matching the request.

    The request's ``id``, ``document_id``, ``doc_name`` and ``doc_type``
    become the filter entity; ``page`` and ``page_size`` select the page.
    """
    criteria = {
        "id": request.id,
        "document_id": request.document_id,
        "doc_name": request.doc_name,
        "doc_type": request.doc_type,
    }
    chunk_filter = DocumentChunkEntity(**criteria)
    return document_chunk_dao.get_document_chunks(
        chunk_filter, page=request.page, page_size=request.page_size
    )
def async_doc_embedding(self, client, chunk_docs, doc):
    """Embed ``chunk_docs`` into the vector store and record the outcome.

    Args:
        client: Embedding client exposing ``knowledge_embedding_batch``.
        chunk_docs: Chunks to embed.
        doc: Document entity whose ``status``, ``result`` and
            ``vector_ids`` are updated in place.

    Returns:
        Whatever ``knowledge_document_dao.update_knowledge_document``
        returns for the updated document.
    """
    logger.info(
        f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
    )
    try:
        vector_ids = client.knowledge_embedding_batch(chunk_docs)
        doc.status = SyncStatus.FINISHED.name
        # The outcome goes into ``result`` only. ``content`` must stay
        # untouched because a later sync reads it as the knowledge source.
        doc.result = "document embedding success"
        # str() each id so numeric ids from a vector store do not break join.
        doc.vector_ids = ",".join(str(vector_id) for vector_id in vector_ids)
        logger.info(f"async document embedding, success:{doc.doc_name}")
    except Exception as e:
        doc.status = SyncStatus.FAILED.name
        doc.result = "document embedding failed" + str(e)
        logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
    return knowledge_document_dao.update_knowledge_document(doc)

View File

@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
CFG = Config()
Base = declarative_base()
class KnowledgeSpaceEntity(Base):
__tablename__ = 'knowledge_space'
__tablename__ = "knowledge_space"
id = Column(Integer, primary_key=True)
name = Column(String(100))
vector_type = Column(String(100))
@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):
class KnowledgeSpaceDao:
def __init__(self):
    """Create the MySQL engine and session factory used by this DAO."""
    database = "knowledge_management"
    # Connection settings come from the shared project configuration.
    db_url = f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}"
    self.db_engine = create_engine(db_url, echo=True)
    self.Session = sessionmaker(bind=self.db_engine)
def create_knowledge_space(self, space:KnowledgeSpaceRequest):
def create_knowledge_space(self, space: KnowledgeSpaceRequest):
session = self.Session()
knowledge_space = KnowledgeSpaceEntity(
name=space.name,
@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
desc=space.desc,
owner=space.owner,
gmt_created=datetime.now(),
gmt_modified=datetime.now()
gmt_modified=datetime.now(),
)
session.add(knowledge_space)
session.commit()
session.close()
def get_knowledge_space(self, query: KnowledgeSpaceEntity):
    """Return knowledge spaces matching every non-``None`` field of ``query``.

    Args:
        query: Entity used as a filter template. Each of ``id``, ``name``,
            ``vector_type``, ``desc``, ``owner``, ``gmt_created`` and
            ``gmt_modified`` that is not ``None`` adds an equality filter.

    Returns:
        List of matching ``KnowledgeSpaceEntity`` rows, newest
        ``gmt_created`` first.
    """
    filter_fields = (
        "id",
        "name",
        "vector_type",
        "desc",
        "owner",
        "gmt_created",
        "gmt_modified",
    )
    session = self.Session()
    try:
        knowledge_spaces = session.query(KnowledgeSpaceEntity)
        for field in filter_fields:
            value = getattr(query, field)
            if value is not None:
                knowledge_spaces = knowledge_spaces.filter(
                    getattr(KnowledgeSpaceEntity, field) == value
                )
        knowledge_spaces = knowledge_spaces.order_by(
            KnowledgeSpaceEntity.gmt_created.desc()
        )
        return knowledge_spaces.all()
    finally:
        # Release the connection; the rows were already loaded by all().
        session.close()
def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
    """Overwrite name, vector_type, desc and owner of one knowledge space.

    Args:
        space_id: Primary key of the row to update.
        space: Object carrying the new ``name``, ``vector_type``, ``desc``
            and ``owner`` values.
    """
    # ``desc`` is a reserved word in MySQL, so it must be back-quoted.
    query = (
        "UPDATE knowledge_space SET name = %s, vector_type = %s, "
        "`desc` = %s, owner = %s WHERE id = %s"
    )
    params = (space.name, space.vector_type, space.desc, space.owner, space_id)
    # This DAO only holds an engine (there is no ``self.conn``), so borrow a
    # DBAPI connection from the engine for the raw statement.
    conn = self.db_engine.raw_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query, params)
            conn.commit()
        finally:
            cursor.close()
    finally:
        conn.close()
def delete_knowledge_space(self, space_id:int):
def delete_knowledge_space(self, space_id: int):
cursor = self.conn.cursor()
query = "DELETE FROM knowledge_space WHERE id = %s"
cursor.execute(query, (space_id,))

View File

@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
"""doc_type: doc type"""
doc_type: str
"""content: content"""
content: str
content: str = None
"""text_chunk_size: text_chunk_size"""
# text_chunk_size: int
class DocumentQueryRequest(BaseModel):
"""doc_name: doc path"""
doc_name: str = None
"""doc_type: doc type"""
doc_type: str= None
doc_type: str = None
"""status: status"""
status: str= None
status: str = None
"""page: page"""
page: int = 1
"""page_size: page size"""
@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel):
    """Request body listing the documents to sync.

    doc_ids: ids of the documents to process; the element type is not
    constrained by the annotation.
    """
    doc_ids: List
class ChunkQueryRequest(BaseModel):
"""id: id"""
id: int = None
"""document_id: doc id"""
document_id: int = None

View File

@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
examples: List[List]
use_example: bool = False

View File

@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:

View File

@ -1,11 +1,13 @@
from enum import Enum
class Scene:
    """Plain record describing a scene: code, description and inner flag."""

    def __init__(self, code, describe, is_inner):
        # Store the constructor arguments unchanged as public attributes.
        self.code, self.describe, self.is_inner = code, describe, is_inner
class ChatScene(Enum):
ChatWithDbExecute = "chat_with_db_execute"
ChatWithDbQA = "chat_with_db_qa"
@ -24,4 +26,3 @@ class ChatScene(Enum):
@staticmethod
def is_valid_mode(mode):
    """Return True when ``mode`` equals the value of some ChatScene member."""
    for scene in ChatScene:
        if mode == scene.value:
            return True
    return False

View File

@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
)
CFG = Config()
@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input, report_name
):
def __init__(self, chat_session_id, db_name, user_input, report_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,
@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": "" #TODO
"supported_chat_type": "" # TODO
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
return input_values
@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
# TODO: fix this flow
print(str(e))
chart_datas.append(chart_data)
report_data.conv_uid = self.chat_session_id
@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
report_data.charts = chart_datas
return report_data

View File

@ -12,11 +12,7 @@ class ChartData(BaseModel):
class ReportData(BaseModel):
    """Report payload: conversation id, template metadata and its charts."""

    # Id of the conversation the report belongs to.
    conv_uid: str
    # Name and introduction text of the report template.
    template_name: str
    template_introduce: str
    # Chart entries that make up the report.
    charts: List[ChartData]

View File

@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR
class ChartItem(NamedTuple):
    """One chart suggestion, in the order sql, title, thoughts, showcase."""

    # SQL statement producing the chart data.
    sql: str
    # Title of the analysis.
    title: str
    # Reasoning attached to the analysis.
    thoughts: str
    # How the result should be displayed (chart type).
    showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
response = json.loads(clean_str)
chart_items = List[ChartItem]
for item in response:
chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
chart_items.append(
ChartItem(
item["sql"], item["title"], item["thoughts"], item["showcase"]
)
)
return chart_items
def parse_view_response(self, speak, data) -> str:

View File

@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
Ensure the response is correct json and can be parsed by Python json.loads
"""
RESPONSE_FORMAT = [{
"sql": "data analysis SQL",
"title": "Data Analysis Title",
"showcase": "What type of charts to show",
"thoughts": "Current thinking and value of data analysis"
}]
RESPONSE_FORMAT = [
{
"sql": "data analysis SQL",
"title": "Data Analysis Title",
"showcase": "What type of charts to show",
"thoughts": "Current thinking and value of data analysis",
}
]
PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input
):
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbExecute,

View File

@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, db_name, user_input
):
def __init__(self, chat_session_id, db_name, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatWithDbQA,
@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
"table_info": table_info,
}
return input_values

View File

@ -34,7 +34,6 @@ RESPONSE_FORMAT = {
}
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
@ -47,8 +46,10 @@ prompt = PromptTemplate(
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
example=example
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
example=example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(
self, chat_session_id, user_input, knowledge_name
):
def __init__(self, chat_session_id, user_input, knowledge_name):
""" """
super().__init__(
chat_mode=ChatScene.ChatNewKnowledge,
@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
@property
def chat_type(self) -> str:
    """Return the string value of ``ChatScene.ChatNewKnowledge``."""
    return ChatScene.ChatNewKnowledge.value

View File

@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input):
def __init__(self, chat_session_id, user_input):
""" """
super().__init__(
chat_mode=ChatScene.ChatDefaultKnowledge,
@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
)
return input_values
@property
def chat_type(self) -> str:
    """Return the string value of ``ChatScene.ChatDefaultKnowledge``."""
    return ChatScene.ChatDefaultKnowledge.value

View File

@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
}
return input_values
@property
def chat_type(self) -> str:
    """Return the string value of ``ChatScene.InnerChatDBSummary``."""
    return ChatScene.InnerChatDBSummary.value

View File

@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
input_values = {"context": context, "question": self.current_user_input}
return input_values
@property
def chat_type(self) -> str:
    """Return the string value of ``ChatScene.ChatUrlKnowledge``."""
    return ChatScene.ChatUrlKnowledge.value

View File

@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, knowledge_space):
def __init__(self, chat_session_id, user_input, knowledge_space):
""" """
super().__init__(
chat_mode=ChatScene.ChatKnowledge,
@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
)
return input_values
@property
def chat_type(self) -> str:
    """Return the string value of ``ChatScene.ChatKnowledge``."""
    return ChatScene.ChatKnowledge.value

View File

@ -85,7 +85,6 @@ class OnceConversation:
self.messages.clear()
self.session_id = None
def get_user_message(self):
for once in self.messages:
if isinstance(once, HumanMessage):

View File

@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
load_dotenv(verbose=True, override=True)
del load_dotenv

View File

@ -122,7 +122,7 @@ app.add_middleware(
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
allow_headers=["*"],
)

View File

@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
@ -107,12 +107,16 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue,
]
def swagger_monkey_patch(*args, **kwargs):
    """Build the Swagger UI page with JS/CSS loaded from cdn.bootcdn.net.

    All positional and keyword arguments are forwarded unchanged to
    ``get_swagger_ui_html``; only the two asset URLs are fixed here.
    """
    js_url = "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js"
    css_url = "https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
    return get_swagger_ui_html(
        *args, **kwargs, swagger_js_url=js_url, swagger_css_url=css_url
    )
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
@ -360,14 +364,18 @@ def http_bot(
response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
state.messages[-1][-1] =msg
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
state.messages[-1][
-1
] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@ -693,8 +701,12 @@ def signal_handler(sig, frame):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument(
"-new", "--new", action="store_true", help="enable new http mode"
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
@ -702,27 +714,24 @@ if __name__ == "__main__":
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
# init server config
args = parser.parse_args()
server_init(args)
if args.new:
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
### Compatibility mode starts the old version server by default
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
concurrency_count=args.concurrency_count,
status_update_rate=10,
api_open=False,
).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
max_threads=200,
)

View File

@ -38,7 +38,9 @@ class KnowledgeEmbedding:
return self.knowledge_embedding_client.read_batch()
def init_knowledge_embedding(self):
    """Build the embedding client for this instance's file type and path.

    The file type is upper-cased before being handed to
    ``get_knowledge_embedding`` together with the file path and the
    vector store configuration.
    """
    knowledge_type = self.file_type.upper()
    return get_knowledge_embedding(
        knowledge_type, self.file_path, self.vector_store_config
    )
def similar_search(self, text, topk):
vector_client = VectorStoreConnector(