Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-28 04:44:14 +00:00

Commit caa1a41065: Merge branch 'llm_framework' into dev_ty_06_end

# Conflicts:
#	pilot/memory/chat_history/duckdb_history.py
#	pilot/openapi/api_v1/api_v1.py
#	pilot/scene/base.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_execution/example.py
@@ -2,14 +2,60 @@

import { useRouter, useSearchParams } from 'next/navigation'
import React, { useState, useEffect } from 'react'
import { Button, Table } from '@/lib/mui'
import {
  Button,
  Table,
  Sheet,
  Modal,
  Box,
  Stack,
  Input,
  styled
} from '@/lib/mui'
import moment from 'moment'
import { message } from 'antd'

const stepsOfAddingDocument = [
  'Choose a Datasource type',
  'Setup the Datasource'
]
const documentTypeList = [
  {
    type: 'text',
    title: 'Text',
    subTitle: 'Paste some text'
  },
  {
    type: 'webPage',
    title: 'Web Page',
    subTitle: 'Crawl text from a web page'
  },
  {
    type: 'file',
    title: 'File',
    subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
  }
]
const Item = styled(Sheet)(({ theme }) => ({
  width: '50%',
  backgroundColor:
    theme.palette.mode === 'dark' ? theme.palette.background.level1 : '#fff',
  ...theme.typography.body2,
  padding: theme.spacing(1),
  textAlign: 'center',
  borderRadius: 4,
  color: theme.vars.palette.text.secondary
}))

const Documents = () => {
  const router = useRouter()
  const spaceName = useSearchParams().get('name')
  const [isAddDocumentModalShow, setIsAddDocumentModalShow] =
    useState<boolean>(false)
  const [activeStep, setActiveStep] = useState<number>(0)
  const [documents, setDocuments] = useState<any>([])
  const [webPageUrl, setWebPageUrl] = useState<string>('')
  const [documentName, setDocumentName] = useState<string>('')
  useEffect(() => {
    async function fetchDocuments() {
      const res = await fetch(
@@ -31,6 +77,19 @@ const Documents = () => {
  }, [])
  return (
    <div className="p-4">
      <Sheet
        sx={{
          display: 'flex',
          flexDirection: 'row-reverse'
        }}
      >
        <Button
          variant="outlined"
          onClick={() => setIsAddDocumentModalShow(true)}
        >
          + Add Datasource
        </Button>
      </Sheet>
      <Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
        <thead>
          <tr>
@@ -70,9 +129,9 @@ const Documents = () => {
              )
              const data = await res.json()
              if (data.success) {
                message.success('success');
                message.success('success')
              } else {
                message.error(data.err_msg || 'failed');
                message.error(data.err_msg || 'failed')
              }
            }}
          >
@@ -95,6 +154,137 @@ const Documents = () => {
          ))}
        </tbody>
      </Table>
      <Modal
        sx={{
          display: 'flex',
          justifyContent: 'center',
          alignItems: 'center',
          'z-index': 1000
        }}
        open={isAddDocumentModalShow}
        onClose={() => setIsAddDocumentModalShow(false)}
      >
        <Sheet
          variant="outlined"
          sx={{
            width: 800,
            borderRadius: 'md',
            p: 3,
            boxShadow: 'lg'
          }}
        >
          <Box sx={{ width: '100%' }}>
            <Stack spacing={2} direction="row">
              {stepsOfAddingDocument.map((item: any, index: number) => (
                <Item
                  key={item}
                  sx={{ fontWeight: activeStep === index ? 'bold' : '' }}
                >
                  {item}
                </Item>
              ))}
            </Stack>
          </Box>
          {activeStep === 0 ? (
            <>
              <Box sx={{ margin: '30px auto' }}>
                {documentTypeList.map((item: any) => (
                  <Sheet
                    key={item.type}
                    sx={{
                      boxSizing: 'border-box',
                      height: '80px',
                      padding: '12px',
                      display: 'flex',
                      flexDirection: 'column',
                      justifyContent: 'space-between',
                      border: '1px solid gray',
                      borderRadius: '6px',
                      marginBottom: '20px',
                      cursor: 'pointer'
                    }}
                    onClick={() => {
                      if (item.type === 'webPage') {
                        setActiveStep(1)
                      }
                    }}
                  >
                    <Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>
                      {item.title}
                    </Sheet>
                    <Sheet>{item.subTitle}</Sheet>
                  </Sheet>
                ))}
              </Box>
            </>
          ) : (
            <>
              <Box sx={{ margin: '30px auto' }}>
                Name:
                <Input
                  placeholder="Please input the name"
                  onChange={(e: any) => setDocumentName(e.target.value)}
                  sx={{ marginBottom: '20px' }}
                />
                Web Page URL:
                <Input
                  placeholder="Please input the Web Page URL"
                  onChange={(e: any) => setWebPageUrl(e.target.value)}
                />
              </Box>
              <Button
                onClick={async () => {
                  if (documentName === '') {
                    message.error('Please input the name')
                    return
                  }
                  if (webPageUrl === '') {
                    message.error('Please input the Web Page URL')
                    return
                  }
                  const res = await fetch(
                    `http://localhost:8000/knowledge/${spaceName}/document/add`,
                    {
                      method: 'POST',
                      headers: {
                        'Content-Type': 'application/json'
                      },
                      body: JSON.stringify({
                        doc_name: documentName,
                        content: webPageUrl,
                        doc_type: 'URL'
                      })
                    }
                  )
                  const data = await res.json()
                  if (data.success) {
                    message.success('success')
                    setIsAddDocumentModalShow(false)
                    const res = await fetch(
                      `http://localhost:8000/knowledge/${spaceName}/document/list`,
                      {
                        method: 'POST',
                        headers: {
                          'Content-Type': 'application/json'
                        },
                        body: JSON.stringify({})
                      }
                    )
                    const data = await res.json()
                    if (data.success) {
                      setDocuments(data.data)
                    }
                  } else {
                    message.error(data.err_msg || 'failed')
                  }
                }}
              >
                Finish
              </Button>
            </>
          )}
        </Sheet>
      </Modal>
    </div>
  )
}
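Note: the modal above posts the new document to the knowledge API. A minimal Python sketch of the same call, for quick testing against a local server (assumes the localhost:8000 host shown in the page code; "my-space" and the URL are placeholder values):

    # same payload the "Finish" button sends via fetch()
    import requests

    resp = requests.post(
        "http://localhost:8000/knowledge/my-space/document/add",
        json={
            "doc_name": "example-page",        # maps to documentName in the form
            "content": "https://example.com",  # maps to webPageUrl
            "doc_type": "URL",
        },
    )
    data = resp.json()
    print(data.get("success"), data.get("err_msg"))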
@@ -1,8 +1,7 @@
'use client'

import { useRouter } from 'next/navigation'
import type { ProFormInstance } from '@ant-design/pro-components'
import React, { useState, useRef, useEffect } from 'react'
import React, { useState, useEffect } from 'react'
import { message } from 'antd'
import {
  Modal,
@@ -31,16 +30,33 @@ const stepsOfAddingSpace = [
  'Choose a Datasource type',
  'Setup the Datasource'
]
const documentTypeList = [
  {
    type: 'text',
    title: 'Text',
    subTitle: 'Paste some text'
  },
  {
    type: 'webPage',
    title: 'Web Page',
    subTitle: 'Crawl text from a web page'
  },
  {
    type: 'file',
    title: 'File',
    subTitle: 'It can be: PDF, CSV, JSON, Text, PowerPoint, Word, Excel'
  }
]

const Index = () => {
  const router = useRouter()
  const formRef = useRef<ProFormInstance>()
  const [activeStep, setActiveStep] = useState<number>(0)
  const [knowledgeSpaceList, setKnowledgeSpaceList] = useState<any>([])
  const [isAddKnowledgeSpaceModalShow, setIsAddKnowledgeSpaceModalShow] =
    useState<boolean>(false)
  const [knowledgeSpaceName, setKnowledgeSpaceName] = useState<string>('')
  const [webPageUrl, setWebPageUrl] = useState<string>('')
  const [documentName, setDocumentName] = useState<string>('')
  useEffect(() => {
    async function fetchData() {
      const res = await fetch('http://localhost:8000/knowledge/space/list', {
@@ -59,15 +75,21 @@ const Index = () => {
  }, [])
  return (
    <>
      <div className="page-header p-4">
        <div className="page-header-title">Knowledge Spaces</div>
      <Sheet sx={{
        display: "flex",
        justifyContent: "space-between"
      }} className="p-4">
        <Sheet sx={{
          fontSize: '30px',
          fontWeight: 'bold'
        }}>Knowledge Spaces</Sheet>
        <Button
          onClick={() => setIsAddKnowledgeSpaceModalShow(true)}
          variant="outlined"
        >
          + New Knowledge Space
        </Button>
      </div>
      </Sheet>
      <div className="page-body p-4">
        <Table sx={{ '& thead th:nth-child(1)': { width: '40%' } }}>
          <thead>
@@ -100,7 +122,6 @@ const Index = () => {
        </Table>
      </div>
      <Modal
        title="Add Knowledge Space"
        sx={{
          display: 'flex',
          justifyContent: 'center',
@@ -191,14 +212,42 @@ const Index = () => {
        ) : activeStep === 1 ? (
          <>
            <Box sx={{ margin: '30px auto' }}>
              <Button variant="outlined" onClick={() => setActiveStep(2)}>
                Web Page
              </Button>
              {documentTypeList.map((item: any) => (
                <Sheet
                  key={item.type}
                  sx={{
                    boxSizing: 'border-box',
                    height: '80px',
                    padding: '12px',
                    display: 'flex',
                    flexDirection: 'column',
                    justifyContent: 'space-between',
                    border: '1px solid gray',
                    borderRadius: '6px',
                    marginBottom: '20px',
                    cursor: 'pointer'
                  }}
                  onClick={() => {
                    if (item.type === 'webPage') {
                      setActiveStep(2);
                    }
                  }}
                >
                  <Sheet sx={{ fontSize: '20px', fontWeight: 'bold' }}>{item.title}</Sheet>
                  <Sheet>{item.subTitle}</Sheet>
                </Sheet>
              ))}
            </Box>
          </>
        ) : (
          <>
            <Box sx={{ margin: '30px auto' }}>
              Name:
              <Input
                placeholder="Please input the name"
                onChange={(e: any) => setDocumentName(e.target.value)}
                sx={{ marginBottom: '20px'}}
              />
              Web Page URL:
              <Input
                placeholder="Please input the Web Page URL"
@@ -207,6 +256,14 @@ const Index = () => {
            </Box>
            <Button
              onClick={async () => {
                if (documentName === '') {
                  message.error('Please input the name');
                  return;
                }
                if (webPageUrl === '') {
                  message.error('Please input the Web Page URL');
                  return;
                }
                const res = await fetch(
                  `http://localhost:8000/knowledge/${knowledgeSpaceName}/document/add`,
                  {
@@ -215,7 +272,8 @@ const Index = () => {
                    'Content-Type': 'application/json'
                  },
                  body: JSON.stringify({
                    doc_name: webPageUrl,
                    doc_name: documentName,
                    content: webPageUrl,
                    doc_type: 'URL'
                  })
                }
@@ -235,24 +293,6 @@ const Index = () => {
        )}
      </Sheet>
    </Modal>
    <style jsx>{`
      .page-header {
        display: flex;
        justify-content: space-between;
        .page-header-title {
          font-size: 30px;
          font-weight: bold;
        }
      }
      .datasource-type-wrap {
        height: 100px;
        line-height: 100px;
        border: 1px solid black;
        border-radius: 20px;
        margin-bottom: 20px;
        cursor: pointer;
      }
    `}</style>
  </>
  )
}

@@ -43,7 +43,7 @@ class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        elif hasattr(obj, '__dict__'):
        elif hasattr(obj, "__dict__"):
            return obj.__dict__
        else:
            return json.JSONEncoder.default(self, obj)
            return json.JSONEncoder.default(self, obj)
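Note: a quick check of the encoder's fallback chain in the MyEncoder hunk above (sets become lists, plain objects serialize via __dict__; the Point class is a made-up example):

    import json

    class MyEncoder(json.JSONEncoder):
        # mirrors the class in the diff above
        def default(self, obj):
            if isinstance(obj, set):
                return list(obj)
            elif hasattr(obj, "__dict__"):
                return obj.__dict__
            return json.JSONEncoder.default(self, obj)

    class Point:
        def __init__(self):
            self.x, self.y = 1, 2

    print(json.dumps({"tags": {"a", "b"}, "p": Point()}, cls=MyEncoder))
    # e.g. {"tags": ["a", "b"], "p": {"x": 1, "y": 2}}  (set order may vary)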
@@ -78,6 +78,7 @@ def load_native_plugins(cfg: Config):
    if not cfg.plugins_auto_load:
        print("not auto load_native_plugins")
        return

    def load_from_git(cfg: Config):
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
@@ -85,16 +86,20 @@ def load_native_plugins(cfg: Config):
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
            response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
                                   headers={'Authorization': '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'})
            response = session.get(
                url.format(repo=native_plugin_repo, branch=branch_name),
                headers={"Authorization": "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"},
            )

            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
                files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
                files = glob.glob(
                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
                )
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
                time_str = now.strftime('%Y%m%d%H%M%S')
                time_str = now.strftime("%Y%m%d%H%M%S")
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
@@ -110,7 +115,6 @@ def load_native_plugins(cfg: Config):
    t.start()


def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

@@ -8,6 +8,7 @@ class SeparatorStyle(Enum):
    THREE = auto()
    FOUR = auto()


class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"

@@ -92,7 +92,7 @@ class Config(metaclass=Singleton):
        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"

        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

@@ -6,7 +6,7 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase,Test1
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2

CFG = Config()
@@ -60,21 +60,21 @@ CFG = Config()

# if __name__ == "__main__":

#     def __extract_json(s):
#         i = s.index("{")
#         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
#         for j, c in enumerate(s[i + 1 :], start=i + 1):
#             if c == "}":
#                 count -= 1
#             elif c == "{":
#                 count += 1
#             if count == 0:
#                 break
#         assert count == 0  # check that the final '}' was found
#         return s[i : j + 1]
#
#     ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
#     print(__extract_json(ss))

if __name__ == "__main__":
    test1 = Test1()
@@ -83,4 +83,4 @@ if __name__ == "__main__":
    test1.test()
    test2.write()
    test1.test()
    test2.test()
    test2.test()
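Note: the commented-out __extract_json helper above can be exercised standalone. This is the same brace-counting logic (find the first '{', track nesting depth, stop when it closes), runnable as-is:

    def extract_json(s):
        i = s.index("{")
        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
        for j, c in enumerate(s[i + 1 :], start=i + 1):
            if c == "}":
                count -= 1
            elif c == "{":
                count += 1
            if count == 0:
                break
        assert count == 0  # check that the final '}' was found
        return s[i : j + 1]

    print(extract_json('prefix {"a": {"b": 1}} suffix'))  # -> {"a": {"b": 1}}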
@@ -4,9 +4,9 @@ from test_cls_base import TestBase


class Test1(TestBase):
    mode:str = "456"
    mode: str = "456"

    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
        self.test_values.append("g")

@@ -3,13 +3,15 @@ from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class Test2(TestBase):
    test_2_values:List = []
    mode:str = "789"
    test_2_values: List = []
    mode: str = "789"

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)
        self.test_values.append(3)
        self.test_2_values.append("x")
        self.test_2_values.append("y")
        self.test_2_values.append("z")
        self.test_2_values.append("z")

@@ -5,9 +5,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

class TestBase(BaseModel, ABC):
    test_values: List = []
    mode:str = "123"
    mode: str = "123"

    def test(self):
        print(self.__class__.__name__ + ":" )
        print(self.__class__.__name__ + ":")
        print(self.test_values)
        print(self.mode)
        print(self.mode)
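Note: these test_cls_* files appear to probe how pydantic treats mutable class-level defaults like test_values: List = []. Under pydantic v1 semantics (an assumption; that is what this codebase seems to target), field defaults are deep-copied per instance, so two instances do not share the list:

    from typing import List
    from pydantic import BaseModel

    class TB(BaseModel):  # hypothetical stand-in for TestBase
        test_values: List = []

    a, b = TB(), TB()
    a.test_values.append("x")
    print(a.test_values, b.test_values)  # expected under pydantic v1: ['x'] []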
@@ -39,7 +39,9 @@ class KnowledgeEmbedding:
        return self.knowledge_embedding_client.read_batch()

    def init_knowledge_embedding(self):
        return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
        return get_knowledge_embedding(
            self.knowledge_type, self.knowledge_source, self.vector_store_config
        )

    def similar_search(self, text, topk):
        vector_client = VectorStoreConnector(
@@ -56,3 +58,9 @@ class KnowledgeEmbedding:
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        return vector_client.vector_name_exists()

    def delete_by_ids(self, ids):
        vector_client = VectorStoreConnector(
            CFG.VECTOR_STORE_TYPE, self.vector_store_config
        )
        vector_client.delete_by_ids(ids=ids)

@@ -33,8 +33,6 @@ class BaseChatHistoryMemory(ABC):
    def clear(self) -> None:
        """Clear session memory from the local file"""


    def conv_list(self, user_name:str=None) -> None:
    def conv_list(self, user_name: str = None) -> None:
        """get user's conversation list"""
        pass

@@ -11,7 +11,7 @@ from pilot.scene.message import (
    conversation_from_dict,
    conversations_to_dict,
)
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList
from pilot.common.custom_data_structure import FixedSizeDict, FixedSizeList

CFG = Config()

@@ -19,7 +19,6 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
    histroies_map = FixedSizeDict(100)


    def __init__(self, chat_session_id: str):
        self.chat_seesion_id = chat_session_id
        self.histroies_map.update({chat_session_id: []})

pilot/model/cache/__init__.py (vendored)
@@ -1,4 +1,4 @@
from .base import Cache
from .disk_cache import DiskCache
from .memory_cache import InMemoryCache
from .gpt_cache import GPTCache
from .gpt_cache import GPTCache

pilot/model/cache/base.py (vendored)
@@ -3,8 +3,8 @@ import hashlib
from typing import Any, Dict
from abc import ABC, abstractmethod


class Cache(ABC):

    def create(self, key: str) -> bool:
        pass

@@ -24,4 +24,4 @@ class Cache(ABC):
    @abstractmethod
    def __contains__(self, key: str) -> bool:
        """see if we can return a cached value for the passed key"""
        pass
        pass

pilot/model/cache/disk_cache.py (vendored)
@@ -3,15 +3,15 @@ import diskcache
import platformdirs
from pilot.model.cache import Cache


class DiskCache(Cache):
    """DiskCache is a cache that uses diskcache lib.
        https://github.com/grantjenks/python-diskcache
    https://github.com/grantjenks/python-diskcache
    """

    def __init__(self, llm_name: str):
        self._diskcache = diskcache.Cache(
            os.path.join(
                platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache"
            )
            os.path.join(platformdirs.user_cache_dir("dbgpt"), f"_{llm_name}.diskcache")
        )

    def __getitem__(self, key: str) -> str:
@@ -22,6 +22,6 @@ class DiskCache(Cache):

    def __contains__(self, key: str) -> bool:
        return key in self._diskcache

    def clear(self):
        self._diskcache.clear()
        self._diskcache.clear()

pilot/model/cache/gpt_cache.py (vendored)
@@ -9,22 +9,23 @@ try:
except ImportError:
    pass


class GPTCache(Cache):

    """
    GPTCache is a semantic cache that uses
    """
    GPTCache is a semantic cache that uses
    """

    def __init__(self, cache) -> None:
        """GPT Cache is a semantic cache that uses GPTCache lib."""

        if isinstance(cache, str):
            _cache = Cache()
            init_similar_cache(
                data_dir=os.path.join(
                    platformdirs.user_cache_dir("dbgpt"), f"_{cache}.gptcache"
                ),
                cache_obj=_cache
                cache_obj=_cache,
            )
        else:
            _cache = cache
@@ -41,4 +42,4 @@ class GPTCache(Cache):
        return get(key) is not None

    def create(self, llm: str, **kwargs: Dict[str, Any]) -> str:
        pass
        pass

pilot/model/cache/memory_cache.py (vendored)
@@ -1,24 +1,23 @@
from typing import Dict, Any
from pilot.model.cache import Cache


class InMemoryCache(Cache):

    def __init__(self) -> None:
        "Initialize that stores things in memory."
        self._cache: Dict[str, Any] = {}

    def create(self, key: str) -> bool:
        pass
        pass

    def clear(self):
        return self._cache.clear()

    def __setitem__(self, key: str, value: str) -> None:
        self._cache[key] = value


    def __getitem__(self, key: str) -> str:
        return self._cache[key]

    def __contains__(self, key: str) -> bool:
        return self._cache.get(key, None) is not None

    def __contains__(self, key: str) -> bool:
        return self._cache.get(key, None) is not None
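Note: the cache classes above all share a small dict-like protocol (__setitem__/__getitem__/__contains__). A standalone usage sketch of the in-memory variant, copied from the class above minus the Cache base (the key and value strings are placeholders):

    from typing import Any, Dict

    class InMemoryCache:
        def __init__(self) -> None:
            self._cache: Dict[str, Any] = {}

        def __setitem__(self, key: str, value: str) -> None:
            self._cache[key] = value

        def __getitem__(self, key: str) -> str:
            return self._cache[key]

        def __contains__(self, key: str) -> bool:
            return self._cache.get(key, None) is not None

    cache = InMemoryCache()
    cache["prompt-hash"] = "cached completion"
    if "prompt-hash" in cache:
        print(cache["prompt-hash"])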
@@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi import APIRouter, Request, Body, status, HTTPException, Response

from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@@ -16,12 +16,13 @@ from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageV
from pilot.configs.config import Config
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)
from pilot.scene.base_message import BaseMessage
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation

@@ -42,19 +43,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE

def __get_conv_user_message(conversations: dict):
    messages = conversations['messages']
    messages = conversations["messages"]
    for item in messages:
        if item['type'] == "human":
            return item['data']['content']
        if item["type"] == "human":
            return item["data"]["content"]
    return ""


@router.get('/v1/chat/dialogue/list', response_model=Result[ConversationVo])
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(response: Response, user_id: str = None):
    # Set the CORS headers
    response.headers['Access-Control-Allow-Origin'] = '*'
    response.headers['Access-Control-Allow-Methods'] = 'GET'
    response.headers['Access-Control-Request-Headers'] = 'content-type'
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "GET"
    response.headers["Access-Control-Request-Headers"] = "content-type"

    dialogues: List = []
    datas = DuckdbHistoryMemory.conv_list(user_id)
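Note: setting CORS headers on each response works, but FastAPI's CORSMiddleware is the usual way to apply the policy once for the whole app; a sketch (the allowed origins and methods here are placeholders, not this project's settings):

    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware

    app = FastAPI()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # placeholder; narrow this in production
        allow_methods=["GET", "POST"],
        allow_headers=["content-type"],
    )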
@@ -65,26 +66,44 @@ async def dialogue_list(response: Response, user_id: str = None):
        conversations = json.loads(messages)

        first_conv: OnceConversation = conversations[0]
        conv_vo: ConversationVo = ConversationVo(conv_uid=conv_uid, user_input=__get_conv_user_message(first_conv),
                                                 chat_mode=first_conv['chat_mode'])
        conv_vo: ConversationVo = ConversationVo(
            conv_uid=conv_uid,
            user_input=__get_conv_user_message(first_conv),
            chat_mode=first_conv["chat_mode"],
        )
        dialogues.append(conv_vo)

    return Result[ConversationVo].succ(dialogues)


@router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]])
@router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]])
async def dialogue_scenes():
    scene_vos: List[ChatSceneVo] = []
    new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution]
    new_modes: List[ChatScene] = [
        ChatScene.ChatDb,
        ChatScene.ChatData,
        ChatScene.ChatDashboard,
        ChatScene.ChatKnowledge,
        ChatScene.ChatExecution,
    ]
    for scene in new_modes:
        if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]:
            scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param")
        if not scene.value in [
            ChatScene.ChatNormal.value,
            ChatScene.InnerChatDBSummary.value,
        ]:
            scene_vo = ChatSceneVo(
                chat_scene=scene.value,
                scene_name=scene.name,
                param_title="Selection Param",
            )
            scene_vos.append(scene_vo)
    return Result.succ(scene_vos)


@router.post('/v1/chat/dialogue/new', response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None):
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
):
    unique_id = uuid.uuid1()
    return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))

@@ -92,7 +111,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str
def get_db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    params:dict = {}
    params: dict = {}
    for name in dbs:
        params.update({name: name})
    return params
@@ -106,11 +125,16 @@ def plugins_select_info():


def knowledge_list():
    """return knowledge space list"""
    params: dict = {}
    request = KnowledgeSpaceRequest()
    return knowledge_service.get_knowledge_space(request)
    spaces = knowledge_service.get_knowledge_space(request)
    for space in spaces:
        params.update({space.name: space.name})
    return params

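Note: knowledge_list now returns a {name: name} dict, matching what get_db_list builds for databases instead of leaking the raw entity list. The update() loop is equivalent to a dict comprehension:

    # equivalent to the params.update() loop above
    params = {space.name: space.name for space in spaces}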
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
@router.post("/v1/chat/mode/params/list", response_model=Result[dict])
async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
    if ChatScene.ChatWithDbQA.value == chat_mode:
        return Result.succ(get_db_list())
@@ -126,14 +150,14 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value):
    return Result.succ(None)


@router.post('/v1/chat/dialogue/delete')
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
    history_mem = DuckdbHistoryMemory(con_uid)
    history_mem.delete()
    return Result.succ(None)


@router.get('/v1/chat/dialogue/messages/history', response_model=Result[MessageVo])
@router.get("/v1/chat/dialogue/messages/history", response_model=Result[MessageVo])
async def dialogue_history_messages(con_uid: str):
    print(f"dialogue_history_messages:{con_uid}")
    message_vos: List[MessageVo] = []
@@ -142,12 +166,14 @@ async def dialogue_history_messages(con_uid: str):
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        for once in history_messages:
            once_message_vos = [message2Vo(element, once['chat_order']) for element in once['messages']]
            once_message_vos = [
                message2Vo(element, once["chat_order"]) for element in once["messages"]
            ]
            message_vos.extend(once_message_vos)
    return Result.succ(message_vos)


@router.post('/v1/chat/completions')
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
    global model_semaphore, global_counter
@@ -157,7 +183,9 @@ async def chat_completions(dialogue: ConversationVo = Body()):
    await model_semaphore.acquire()

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
        raise StopAsyncIteration(
            Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")
        )

    chat_param = {
        "chat_session_id": dialogue.conv_uid,

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any

T = TypeVar('T')
T = TypeVar("T")


class Result(Generic[T], BaseModel):
@@ -24,15 +24,17 @@ class Result(Generic[T], BaseModel):


class ChatSceneVo(BaseModel):
    chat_scene: str = Field(..., description="chat_scene")
    scene_name: str = Field(..., description="chat_scene name show for user")
    param_title: str = Field(..., description="chat_scene required parameter title")
    chat_scene: str = Field(..., description="chat_scene")
    scene_name: str = Field(..., description="chat_scene name show for user")
    param_title: str = Field(..., description="chat_scene required parameter title")


class ConversationVo(BaseModel):
    """
    dialogue_uid
    """
    conv_uid: str = Field(..., description="dialogue uid")

    conv_uid: str = Field(..., description="dialogue uid")
    """
    user input
    """
@@ -44,7 +46,7 @@ class ConversationVo(BaseModel):
    """
    the scene of chat
    """
    chat_mode: str = Field(..., description="the scene of chat ")
    chat_mode: str = Field(..., description="the scene of chat ")

    """
    chat scene select param
@@ -52,11 +54,11 @@ class ConversationVo(BaseModel):
    select_param: str = None


class MessageVo(BaseModel):
    """
    role that sends out the current message
    role that sends out the current message
    """

    role: str
    """
    current message
@@ -70,4 +72,3 @@ class MessageVo(BaseModel):
    time the current message was sent
    """
    time_stamp: Any = None

@@ -10,8 +10,10 @@ from pilot.configs.config import Config
CFG = Config()

Base = declarative_base()


class DocumentChunkEntity(Base):
    __tablename__ = 'document_chunk'
    __tablename__ = "document_chunk"
    id = Column(Integer, primary_key=True)
    document_id = Column(Integer)
    doc_name = Column(String(100))
@@ -29,43 +31,55 @@ class DocumentChunkDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
            echo=True)
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
        self.Session = sessionmaker(bind=self.db_engine)

    def create_documents_chunks(self, documents:List):
    def create_documents_chunks(self, documents: List):
        session = self.Session()
        docs = [
            DocumentChunkEntity(
                doc_name=document.doc_name,
                doc_type=document.doc_type,
                document_id=document.document_id,
                content=document.content or "",
                meta_info=document.meta_info or "",
                gmt_created=datetime.now(),
                gmt_modified=datetime.now()
                doc_name=document.doc_name,
                doc_type=document.doc_type,
                document_id=document.document_id,
                content=document.content or "",
                meta_info=document.meta_info or "",
                gmt_created=datetime.now(),
                gmt_modified=datetime.now(),
            )
            for document in documents]
            for document in documents
        ]
        session.add_all(docs)
        session.commit()
        session.close()

    def get_document_chunks(self, query:DocumentChunkEntity, page=1, page_size=20):
    def get_document_chunks(self, query: DocumentChunkEntity, page=1, page_size=20):
        session = self.Session()
        document_chunks = session.query(DocumentChunkEntity)
        if query.id is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.id == query.id)
        if query.document_id is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.document_id == query.document_id)
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.document_id == query.document_id
            )
        if query.doc_type is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.doc_type == query.doc_type)
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.doc_type == query.doc_type
            )
        if query.doc_name is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.doc_name == query.doc_name)
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.doc_name == query.doc_name
            )
        if query.meta_info is not None:
            document_chunks = document_chunks.filter(DocumentChunkEntity.meta_info == query.meta_info)
            document_chunks = document_chunks.filter(
                DocumentChunkEntity.meta_info == query.meta_info
            )

        document_chunks = document_chunks.order_by(DocumentChunkEntity.id.desc())
        document_chunks = document_chunks.offset((page - 1) * page_size).limit(page_size)
        document_chunks = document_chunks.offset((page - 1) * page_size).limit(
            page_size
        )
        result = document_chunks.all()
        return result
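Note: the DAO above builds a single SQLAlchemy Query and narrows it field by field, then pages with offset/limit. The same pattern in isolation, as a condensed sketch (entity and session names reused from the code above; hypothetical helper):

    def query_chunks(session, query, page=1, page_size=20):
        # start from the full table, then narrow per provided field
        q = session.query(DocumentChunkEntity)
        if query.document_id is not None:
            q = q.filter(DocumentChunkEntity.document_id == query.document_id)
        if query.doc_name is not None:
            q = q.filter(DocumentChunkEntity.doc_name == query.doc_name)
        # newest first, then page: rows (page-1)*page_size .. page*page_size-1
        q = q.order_by(DocumentChunkEntity.id.desc())
        return q.offset((page - 1) * page_size).limit(page_size).all()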
@@ -13,7 +13,11 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
from pilot.openapi.knowledge.request.knowledge_request import (
    KnowledgeQueryRequest,
    KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest,
    KnowledgeQueryResponse,
    KnowledgeDocumentRequest,
    DocumentSyncRequest,
    ChunkQueryRequest,
    DocumentQueryRequest,
)

from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
@@ -62,16 +66,15 @@ def document_add(space_name: str, request: KnowledgeDocumentRequest):
def document_list(space_name: str, query_request: DocumentQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(knowledge_space_service.get_knowledge_documents(
            space_name,
            query_request
        ))
        return Result.succ(
            knowledge_space_service.get_knowledge_documents(space_name, query_request)
        )
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document list error {e}")


@router.post("/knowledge/{space_name}/document/upload")
def document_sync(space_name: str, file: UploadFile = File(...)):
async def document_sync(space_name: str, file: UploadFile = File(...)):
    print(f"/document/upload params: {space_name}")
    try:
        with NamedTemporaryFile(delete=False) as tmp:
@@ -92,7 +95,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
        knowledge_space_service.sync_knowledge_document(
            space_name=space_name, doc_ids=request.doc_ids
        )
        Result.succ([])
        return Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document sync error {e}")

@@ -101,9 +104,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest):
def document_list(space_name: str, query_request: ChunkQueryRequest):
    print(f"/document/list params: {space_name}, {query_request}")
    try:
        return Result.succ(knowledge_space_service.get_document_chunks(
            query_request
        ))
        return Result.succ(knowledge_space_service.get_document_chunks(query_request))
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document chunk list error {e}")

@@ -12,7 +12,7 @@ Base = declarative_base()


class KnowledgeDocumentEntity(Base):
    __tablename__ = 'knowledge_document'
    __tablename__ = "knowledge_document"
    id = Column(Integer, primary_key=True)
    doc_name = Column(String(100))
    doc_type = Column(String(100))
@@ -21,23 +21,25 @@ class KnowledgeDocumentEntity(Base):
    status = Column(String(100))
    last_sync = Column(String(100))
    content = Column(Text)
    result = Column(Text)
    vector_ids = Column(Text)
    gmt_created = Column(DateTime)
    gmt_modified = Column(DateTime)

    def __repr__(self):
        return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
        return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"


class KnowledgeDocumentDao:
    def __init__(self):
        database = "knowledge_management"
        self.db_engine = create_engine(
            f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}',
            echo=True)
            f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
        self.Session = sessionmaker(bind=self.db_engine)

    def create_knowledge_document(self, document:KnowledgeDocumentEntity):
    def create_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        knowledge_document = KnowledgeDocumentEntity(
            doc_name=document.doc_name,
@@ -47,9 +49,10 @@ class KnowledgeDocumentDao:
            status=document.status,
            last_sync=document.last_sync,
            content=document.content or "",
            result=document.result or "",
            vector_ids=document.vector_ids,
            gmt_created=datetime.now(),
            gmt_modified=datetime.now()
            gmt_modified=datetime.now(),
        )
        session.add(knowledge_document)
        session.commit()
@@ -60,28 +63,42 @@ class KnowledgeDocumentDao:
        session = self.Session()
        knowledge_documents = session.query(KnowledgeDocumentEntity)
        if query.id is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.id == query.id)
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.id == query.id
            )
        if query.doc_name is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_name == query.doc_name)
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.doc_name == query.doc_name
            )
        if query.doc_type is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.doc_type == query.doc_type)
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.doc_type == query.doc_type
            )
        if query.space is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.space == query.space)
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.space == query.space
            )
        if query.status is not None:
            knowledge_documents = knowledge_documents.filter(KnowledgeDocumentEntity.status == query.status)
            knowledge_documents = knowledge_documents.filter(
                KnowledgeDocumentEntity.status == query.status
            )

        knowledge_documents = knowledge_documents.order_by(KnowledgeDocumentEntity.id.desc())
        knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(page_size)
        knowledge_documents = knowledge_documents.order_by(
            KnowledgeDocumentEntity.id.desc()
        )
        knowledge_documents = knowledge_documents.offset((page - 1) * page_size).limit(
            page_size
        )
        result = knowledge_documents.all()
        return result

    def update_knowledge_document(self, document:KnowledgeDocumentEntity):
    def update_knowledge_document(self, document: KnowledgeDocumentEntity):
        session = self.Session()
        updated_space = session.merge(document)
        session.commit()
        return updated_space.id

    def delete_knowledge_document(self, document_id:int):
    def delete_knowledge_document(self, document_id: int):
        cursor = self.conn.cursor()
        query = "DELETE FROM knowledge_document WHERE id = %s"
        cursor.execute(query, (document_id,))

@@ -4,17 +4,24 @@ from datetime import datetime
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.logs import logger
from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao
from pilot.openapi.knowledge.document_chunk_dao import (
    DocumentChunkEntity,
    DocumentChunkDao,
)
from pilot.openapi.knowledge.knowledge_document_dao import (
    KnowledgeDocumentDao,
    KnowledgeDocumentEntity,
)
from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity
from pilot.openapi.knowledge.knowledge_space_dao import (
    KnowledgeSpaceDao,
    KnowledgeSpaceEntity,
)
from pilot.openapi.knowledge.request.knowledge_request import (
    KnowledgeSpaceRequest,
    KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest,
    KnowledgeDocumentRequest,
    DocumentQueryRequest,
    ChunkQueryRequest,
)
from enum import Enum

@@ -23,7 +30,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()

CFG=Config()
CFG = Config()


class SyncStatus(Enum):
@@ -53,10 +60,7 @@ class KnowledgeService:
    """create knowledge document"""

    def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            space=space
        )
        query = KnowledgeDocumentEntity(doc_name=request.doc_name, space=space)
        documents = knowledge_document_dao.get_knowledge_documents(query)
        if len(documents) > 0:
            raise Exception(f"document name:{request.doc_name} have already named")
@@ -74,26 +78,27 @@ class KnowledgeService:

    """get knowledge space"""

    def get_knowledge_space(self, request:KnowledgeSpaceRequest):
    def get_knowledge_space(self, request: KnowledgeSpaceRequest):
        query = KnowledgeSpaceEntity(
            name=request.name,
            vector_type=request.vector_type,
            owner=request.owner
            name=request.name, vector_type=request.vector_type, owner=request.owner
        )
        return knowledge_space_dao.get_knowledge_space(query)

    """get knowledge get_knowledge_documents"""

    def get_knowledge_documents(self, space, request:DocumentQueryRequest):
    def get_knowledge_documents(self, space, request: DocumentQueryRequest):
        query = KnowledgeDocumentEntity(
            doc_name=request.doc_name,
            doc_type=request.doc_type,
            space=space,
            status=request.status,
        )
        return knowledge_document_dao.get_knowledge_documents(query, page=request.page, page_size=request.page_size)
        return knowledge_document_dao.get_knowledge_documents(
            query, page=request.page, page_size=request.page_size
        )

    """sync knowledge document chunk into vector store"""

    def sync_knowledge_document(self, space_name, doc_ids):
        for doc_id in doc_ids:
            query = KnowledgeDocumentEntity(
@@ -101,12 +106,14 @@ class KnowledgeService:
                space=space_name,
            )
            doc = knowledge_document_dao.get_knowledge_documents(query)[0]
            client = KnowledgeEmbedding(knowledge_source=doc.content,
                                        knowledge_type=doc.doc_type.upper(),
                                        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                                        vector_store_config={
                                            "vector_store_name": space_name,
                                        })
            client = KnowledgeEmbedding(
                knowledge_source=doc.content,
                knowledge_type=doc.doc_type.upper(),
                model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
                vector_store_config={
                    "vector_store_name": space_name,
                },
            )
            chunk_docs = client.read()
            # update document status
            doc.status = SyncStatus.RUNNING.name
@@ -114,9 +121,12 @@ class KnowledgeService:
            doc.gmt_modified = datetime.now()
            knowledge_document_dao.update_knowledge_document(doc)
            # async doc embeddings
            thread = threading.Thread(target=self.async_doc_embedding(client, chunk_docs, doc))
            thread = threading.Thread(
                target=self.async_doc_embedding, args=(client, chunk_docs, doc)
            )
            thread.start()
            #save chunk details
            logger.info(f"begin save document chunks, doc:{doc.doc_name}")
            # save chunk details
            chunk_entities = [
                DocumentChunkEntity(
                    doc_name=doc.doc_name,
@@ -125,9 +135,10 @@ class KnowledgeService:
                    content=chunk_doc.page_content,
                    meta_info=str(chunk_doc.metadata),
                    gmt_created=datetime.now(),
                    gmt_modified=datetime.now()
                    gmt_modified=datetime.now(),
                )
                for chunk_doc in chunk_docs]
                for chunk_doc in chunk_docs
            ]
            document_chunk_dao.create_documents_chunks(chunk_entities)

        return True
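Note: the threading change above fixes a real bug. Thread(target=self.async_doc_embedding(client, chunk_docs, doc)) calls the method immediately on the current thread and passes its return value as target, so nothing runs in the background. Passing the callable plus args defers the call to the new thread; a minimal demo (work and "doc-1" are made-up names):

    import threading

    def work(x):
        print("embedding", x)

    # wrong: work("doc-1") runs here, synchronously, before the Thread even starts
    t_bad = threading.Thread(target=work("doc-1"))

    # right: the callable and its arguments are handed to the new thread
    t_good = threading.Thread(target=work, args=("doc-1",))
    t_good.start()
    t_good.join()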
@ -145,26 +156,30 @@ class KnowledgeService:
|
||||
return knowledge_space_dao.delete_knowledge_space(space_id)
|
||||
|
||||
"""get document chunks"""
|
||||
def get_document_chunks(self, request:ChunkQueryRequest):
|
||||
|
||||
def get_document_chunks(self, request: ChunkQueryRequest):
|
||||
query = DocumentChunkEntity(
|
||||
id=request.id,
|
||||
document_id=request.document_id,
|
||||
doc_name=request.doc_name,
|
||||
doc_type=request.doc_type
|
||||
doc_type=request.doc_type,
|
||||
)
|
||||
return document_chunk_dao.get_document_chunks(
|
||||
query, page=request.page, page_size=request.page_size
|
||||
)
|
||||
return document_chunk_dao.get_document_chunks(query, page=request.page, page_size=request.page_size)
|
||||
|
||||
def async_doc_embedding(self, client, chunk_docs, doc):
|
||||
logger.info(f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}")
|
||||
logger.info(
|
||||
f"async_doc_embedding, doc:{doc.doc_name}, chunk_size:{len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
|
||||
)
|
||||
try:
|
||||
vector_ids = client.knowledge_embedding_batch(chunk_docs)
|
||||
doc.status = SyncStatus.FINISHED.name
|
||||
doc.content = "embedding success"
|
||||
doc.result = "document embedding success"
|
||||
doc.vector_ids = ",".join(vector_ids)
|
||||
logger.info(f"async document embedding, success:{doc.doc_name}")
|
||||
except Exception as e:
|
||||
doc.status = SyncStatus.FAILED.name
|
||||
doc.content = str(e)
|
||||
|
||||
doc.result = "document embedding failed" + str(e)
|
||||
logger.error(f"document embedding, failed:{doc.doc_name}, {str(e)}")
|
||||
return knowledge_document_dao.update_knowledge_document(doc)
|
||||
|
||||
|
||||
|
@ -10,8 +10,10 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeSpaceEntity(Base):
|
||||
__tablename__ = 'knowledge_space'
|
||||
__tablename__ = "knowledge_space"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(100))
|
||||
vector_type = Column(String(100))
|
||||
@ -27,10 +29,13 @@ class KnowledgeSpaceEntity(Base):
|
||||
class KnowledgeSpaceDao:
|
||||
def __init__(self):
|
||||
database = "knowledge_management"
|
||||
self.db_engine = create_engine(f'mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}', echo=True)
|
||||
self.db_engine = create_engine(
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
             echo=True,
         )
         self.Session = sessionmaker(bind=self.db_engine)

-    def create_knowledge_space(self, space:KnowledgeSpaceRequest):
+    def create_knowledge_space(self, space: KnowledgeSpaceRequest):
         session = self.Session()
         knowledge_space = KnowledgeSpaceEntity(
             name=space.name,
@@ -38,43 +43,61 @@ class KnowledgeSpaceDao:
             desc=space.desc,
             owner=space.owner,
             gmt_created=datetime.now(),
-            gmt_modified=datetime.now()
+            gmt_modified=datetime.now(),
         )
         session.add(knowledge_space)
         session.commit()

         session.close()

-    def get_knowledge_space(self, query:KnowledgeSpaceEntity):
+    def get_knowledge_space(self, query: KnowledgeSpaceEntity):
         session = self.Session()
         knowledge_spaces = session.query(KnowledgeSpaceEntity)
         if query.id is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.id == query.id)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.id == query.id
+            )
         if query.name is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.name == query.name)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.name == query.name
+            )
         if query.vector_type is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.vector_type == query.vector_type)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.vector_type == query.vector_type
+            )
         if query.desc is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.desc == query.desc)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.desc == query.desc
+            )
         if query.owner is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.owner == query.owner)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.owner == query.owner
+            )
         if query.gmt_created is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_created == query.gmt_created)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.gmt_created == query.gmt_created
+            )
         if query.gmt_modified is not None:
-            knowledge_spaces = knowledge_spaces.filter(KnowledgeSpaceEntity.gmt_modified == query.gmt_modified)
+            knowledge_spaces = knowledge_spaces.filter(
+                KnowledgeSpaceEntity.gmt_modified == query.gmt_modified
+            )

-        knowledge_spaces = knowledge_spaces.order_by(KnowledgeSpaceEntity.gmt_created.desc())
+        knowledge_spaces = knowledge_spaces.order_by(
+            KnowledgeSpaceEntity.gmt_created.desc()
+        )
         result = knowledge_spaces.all()
         return result

-    def update_knowledge_space(self, space_id:int, space:KnowledgeSpaceEntity):
+    def update_knowledge_space(self, space_id: int, space: KnowledgeSpaceEntity):
         cursor = self.conn.cursor()
         query = "UPDATE knowledge_space SET name = %s, vector_type = %s, desc = %s, owner = %s WHERE id = %s"
-        cursor.execute(query, (space.name, space.vector_type, space.desc, space.owner, space_id))
+        cursor.execute(
+            query, (space.name, space.vector_type, space.desc, space.owner, space_id)
+        )
         self.conn.commit()
         cursor.close()

-    def delete_knowledge_space(self, space_id:int):
+    def delete_knowledge_space(self, space_id: int):
         cursor = self.conn.cursor()
         query = "DELETE FROM knowledge_space WHERE id = %s"
         cursor.execute(query, (space_id,))
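The conditional filter chain in get_knowledge_space above builds the WHERE clause from whichever attributes of the probe entity are set. A minimal usage sketch, with class and field names taken from the diff and the probe construction assumed:

    dao = KnowledgeSpaceDao()
    # Only non-None attributes become filters, so a sparse entity selects on
    # exactly the fields the caller filled in; results come back newest first.
    probe = KnowledgeSpaceEntity(owner="admin", vector_type="Chroma")
    for space in dao.get_knowledge_space(probe):
        print(space.name, space.gmt_created)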
@@ -30,17 +30,19 @@ class KnowledgeDocumentRequest(BaseModel):
     """doc_type: doc type"""
     doc_type: str
     """content: content"""
-    content: str
+    content: str = None
     """text_chunk_size: text_chunk_size"""
     # text_chunk_size: int


 class DocumentQueryRequest(BaseModel):
     """doc_name: doc path"""

     doc_name: str = None
     """doc_type: doc type"""
-    doc_type: str= None
+    doc_type: str = None
     """status: status"""
-    status: str= None
+    status: str = None
     """page: page"""
     page: int = 1
     """page_size: page size"""
@@ -49,10 +51,13 @@ class DocumentQueryRequest(BaseModel):

 class DocumentSyncRequest(BaseModel):
     """doc_ids: doc ids"""

     doc_ids: List


 class ChunkQueryRequest(BaseModel):
     """id: id"""

     id: int = None
     """document_id: doc id"""
     document_id: int = None
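A quick sketch of how the loosened query model behaves; pydantic v1 treats a `str = None` default as Optional[str], and the page_size default is assumed since it sits outside this hunk:

    from pydantic import BaseModel

    class DocumentQueryRequest(BaseModel):
        doc_name: str = None
        doc_type: str = None
        status: str = None
        page: int = 1
        page_size: int = 20  # assumed default, not shown in the hunk

    req = DocumentQueryRequest(doc_type="URL", page=2)
    # Unset filters stay None and can be dropped before building the DB query:
    print(req.dict(exclude_none=True))  # {'doc_type': 'URL', 'page': 2, 'page_size': 20}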
@@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

+from pilot.common.schema import ExampleType


 class ExampleSelector(BaseModel, ABC):
     examples: List[List]
     use_example: bool = False

@@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser
 from pilot.common.schema import SeparatorStyle
+from pilot.prompts.example_base import ExampleSelector


 def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
     try:
@@ -1,11 +1,13 @@
 from enum import Enum


 class Scene:
     def __init__(self, code, describe, is_inner):
         self.code = code
         self.describe = describe
         self.is_inner = is_inner


 class ChatScene(Enum):
     ChatWithDbExecute = "chat_with_db_execute"
     ChatWithDbQA = "chat_with_db_qa"
@@ -24,4 +26,3 @@ class ChatScene(Enum):
     @staticmethod
     def is_valid_mode(mode):
         return any(mode == item.value for item in ChatScene)
@@ -12,7 +12,10 @@ from pilot.common.markdown_text import (
     generate_htm_table,
 )
 from pilot.scene.chat_db.auto_execute.prompt import prompt
-from pilot.scene.chat_dashboard.data_preparation.report_schma import ChartData, ReportData
+from pilot.scene.chat_dashboard.data_preparation.report_schma import (
+    ChartData,
+    ReportData,
+)

 CFG = Config()

@@ -22,9 +25,7 @@ class ChatDashboard(BaseChat):
     report_name: str
     """Number of results to return from the query"""

-    def __init__(
-        self, chat_session_id, db_name, user_input, report_name
-    ):
+    def __init__(self, chat_session_id, db_name, user_input, report_name):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatWithDbExecute,
@@ -51,7 +52,7 @@ class ChatDashboard(BaseChat):
             "input": self.current_user_input,
             "dialect": self.database.dialect,
             "table_info": self.database.table_simple_info(self.db_connect),
-            "supported_chat_type": "" #TODO
+            "supported_chat_type": ""  # TODO
             # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
         }
         return input_values
@@ -68,7 +69,6 @@ class ChatDashboard(BaseChat):
                 # TODO: fix this flow
                 print(str(e))

             chart_datas.append(chart_data)

         report_data.conv_uid = self.chat_session_id
@@ -77,5 +77,3 @@ class ChatDashboard(BaseChat):
         report_data.charts = chart_datas

         return report_data
@@ -12,11 +12,7 @@ class ChartData(BaseModel):


 class ReportData(BaseModel):
-    conv_uid:str
-    template_name:str
-    template_introduce:str
+    conv_uid: str
+    template_name: str
+    template_introduce: str
     charts: List[ChartData]
@@ -10,9 +10,9 @@ from pilot.configs.model_config import LOGDIR

 class ChartItem(NamedTuple):
     sql: str
-    title:str
+    title: str
     thoughts: str
-    showcase:str
+    showcase: str


 logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
@@ -28,7 +28,11 @@ class ChatDashboardOutputParser(BaseOutputParser):
         response = json.loads(clean_str)
         chart_items = List[ChartItem]
         for item in response:
-            chart_items.append(ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"]))
+            chart_items.append(
+                ChartItem(
+                    item["sql"], item["title"], item["thoughts"], item["showcase"]
+                )
+            )
         return chart_items

     def parse_view_response(self, speak, data) -> str:
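One note on the hunk above, not part of the commit: `chart_items = List[ChartItem]` binds the typing alias itself rather than an empty list, so the later `chart_items.append(...)` fails at runtime. A working accumulator would be an annotated empty list:

    chart_items: List[ChartItem] = []
    for item in response:
        chart_items.append(
            ChartItem(item["sql"], item["title"], item["thoughts"], item["showcase"])
        )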
@@ -24,12 +24,14 @@ give {dialect} data analysis SQL, analysis title, display method and analytical
 Ensure the response is correct json and can be parsed by Python json.loads
 """

-RESPONSE_FORMAT = [{
-    "sql": "data analysis SQL",
-    "title": "Data Analysis Title",
-    "showcase": "What type of charts to show",
-    "thoughts": "Current thinking and value of data analysis"
-}]
+RESPONSE_FORMAT = [
+    {
+        "sql": "data analysis SQL",
+        "title": "Data Analysis Title",
+        "showcase": "What type of charts to show",
+        "thoughts": "Current thinking and value of data analysis",
+    }
+]

 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -21,9 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):

     """Number of results to return from the query"""

-    def __init__(
-        self, chat_session_id, db_name, user_input
-    ):
+    def __init__(self, chat_session_id, db_name, user_input):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatWithDbExecute,
@@ -19,9 +19,7 @@ class ChatWithDbQA(BaseChat):

     """Number of results to return from the query"""

-    def __init__(
-        self, chat_session_id, db_name, user_input
-    ):
+    def __init__(self, chat_session_id, db_name, user_input):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatWithDbQA,
@@ -63,5 +61,3 @@ class ChatWithDbQA(BaseChat):
             "table_info": table_info,
         }
         return input_values
@@ -34,7 +34,6 @@ RESPONSE_FORMAT = {
 }


 EXAMPLE_TYPE = ExampleType.ONE_SHOT
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 ### Whether the model service is streaming output
@@ -47,8 +46,10 @@ prompt = PromptTemplate(
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE,
     stream_out=PROMPT_NEED_STREAM_OUT,
-    output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
-    example=example
+    output_parser=PluginChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
+    ),
+    example=example,
 )

 CFG.prompt_templates.update({prompt.template_scene: prompt})
@@ -27,9 +27,7 @@ class ChatNewKnowledge(BaseChat):

     """Number of results to return from the query"""

-    def __init__(
-        self, chat_session_id, user_input, knowledge_name
-    ):
+    def __init__(self, chat_session_id, user_input, knowledge_name):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatNewKnowledge,
@@ -56,7 +54,6 @@ class ChatNewKnowledge(BaseChat):
         input_values = {"context": context, "question": self.current_user_input}
         return input_values

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatNewKnowledge.value
@@ -29,7 +29,7 @@ class ChatDefaultKnowledge(BaseChat):

     """Number of results to return from the query"""

     def __init__(self, chat_session_id, user_input):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatDefaultKnowledge,
@@ -59,8 +59,6 @@ class ChatDefaultKnowledge(BaseChat):
         )
         return input_values

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatDefaultKnowledge.value
@@ -36,7 +36,6 @@ class InnerChatDBSummary(BaseChat):
         }
         return input_values

     @property
     def chat_type(self) -> str:
         return ChatScene.InnerChatDBSummary.value
@@ -61,7 +61,6 @@ class ChatUrlKnowledge(BaseChat):
         input_values = {"context": context, "question": self.current_user_input}
         return input_values

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatUrlKnowledge.value
@@ -29,7 +29,7 @@ class ChatKnowledge(BaseChat):

     """Number of results to return from the query"""

     def __init__(self, chat_session_id, user_input, knowledge_space):
         """ """
         super().__init__(
             chat_mode=ChatScene.ChatKnowledge,
@@ -59,8 +59,6 @@ class ChatKnowledge(BaseChat):
         )
         return input_values

     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value
@@ -85,7 +85,6 @@ class OnceConversation:
         self.messages.clear()
         self.session_id = None

     def get_user_message(self):
         for once in self.messages:
             if isinstance(once, HumanMessage):
@@ -13,5 +13,3 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"):
     load_dotenv(verbose=True, override=True)

 del load_dotenv
@@ -122,7 +122,7 @@ app.add_middleware(
     allow_origins=origins,
     allow_credentials=True,
     allow_methods=["*"],
-    allow_headers=["*"]
+    allow_headers=["*"],
 )
@@ -56,7 +56,7 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles

@@ -107,12 +107,16 @@ knowledge_qa_type_list = [
     add_knowledge_base_dialogue,
 ]


 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
     )


 applications.get_swagger_ui_html = swagger_monkey_patch

 app = FastAPI()
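The swagger_monkey_patch above is a small but useful trick: FastAPI's /docs route resolves get_swagger_ui_html through the fastapi.applications module when the page is served, so reassigning that module attribute swaps the Swagger UI assets for a mirror CDN without touching FastAPI itself. A standalone sketch, with the CDN URLs taken from the diff and everything else assumed:

    from fastapi import FastAPI, applications
    from fastapi.openapi.docs import get_swagger_ui_html

    def swagger_monkey_patch(*args, **kwargs):
        # Serve the Swagger UI bundle and stylesheet from an alternate CDN.
        return get_swagger_ui_html(
            *args,
            **kwargs,
            swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
            swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css",
        )

    applications.get_swagger_ui_html = swagger_monkey_patch
    app = FastAPI()  # /docs now loads its JS/CSS from the patched URLs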
@@ -360,14 +364,18 @@ def http_bot(
         response = chat.stream_call()
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
-                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-                state.messages[-1][-1] =msg
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
+                    chunk, chat.skip_echo_len
+                )
+                state.messages[-1][-1] = msg
                 chat.current_message.add_ai_message(msg)
                 yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
         chat.memory.append(chat.current_message)
     except Exception as e:
         print(traceback.format_exc())
-        state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
+        state.messages[-1][
+            -1
+        ] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
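For context, the rewrapped loop above consumes a NUL-delimited HTTP stream: each b"\0"-separated frame is typically the model's cumulative partial output, which is why the handler overwrites state.messages[-1][-1] instead of appending. A minimal client-side sketch of the same pattern; the endpoint URL and payload are placeholders, not part of the commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/generate_stream",  # placeholder URL
        json={"prompt": "hello"},
        stream=True,
    )
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:  # skip empty frames between delimiters
            partial = chunk.decode("utf-8", errors="ignore")
            print("\r" + partial, end="")  # redraw with the latest partial answer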
@@ -693,8 +701,12 @@ def signal_handler(sig, frame):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
-    parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )
+    parser.add_argument(
+        "-new", "--new", action="store_true", help="enable new http mode"
+    )

     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -702,27 +714,24 @@ if __name__ == "__main__":
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument("--share", default=False, action="store_true")

     # init server config
     args = parser.parse_args()
     server_init(args)

     if args.new:
         import uvicorn

         uvicorn.run(app, host="0.0.0.0", port=5000)
     else:
         ### Compatibility mode starts the old version server by default
         demo = build_webdemo()
         demo.queue(
-            concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+            concurrency_count=args.concurrency_count,
+            status_update_rate=10,
+            api_open=False,
         ).launch(
             server_name=args.host,
             server_port=args.port,
             share=args.share,
             max_threads=200,
         )
@@ -38,7 +38,9 @@ class KnowledgeEmbedding:
         return self.knowledge_embedding_client.read_batch()

     def init_knowledge_embedding(self):
-        return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config)
+        return get_knowledge_embedding(
+            self.file_type.upper(), self.file_path, self.vector_store_config
+        )

     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
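init_knowledge_embedding above defers to get_knowledge_embedding, which by its signature dispatches on the upper-cased file type. A generic sketch of that type-keyed factory pattern; the client classes and registry contents here are stand-ins, not the project's real implementations:

    class PDFEmbedding:
        def __init__(self, file_path, vector_store_config):
            self.file_path = file_path
            self.vector_store_config = vector_store_config

    class CSVEmbedding(PDFEmbedding):
        pass

    def get_knowledge_embedding(file_type, file_path, vector_store_config):
        registry = {"PDF": PDFEmbedding, "CSV": CSVEmbedding}  # illustrative map
        try:
            client_cls = registry[file_type]
        except KeyError:
            raise ValueError(f"unsupported document type: {file_type}")
        return client_cls(file_path, vector_store_config)

    client = get_knowledge_embedding("report.pdf".rsplit(".", 1)[-1].upper(),
                                     "report.pdf", {"vector_store_name": "default"})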