doc:update dbgpt_demo.mp4

1.update dbgpt_demo.mp4
2.format code
This commit is contained in:
aries_ckt
2023-07-06 13:47:46 +08:00
parent 47595aa10f
commit eb31d5523e
31 changed files with 243 additions and 128 deletions

View File

@@ -452,5 +452,3 @@ class Database:
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]

View File

@@ -86,7 +86,7 @@ if __name__ == "__main__":
# test2.test()
# 定义包含元组的列表
data = [('key1', 'value1'), ('key2', 'value2'), ('key3', 'value3')]
data = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")]
# 使用字典解析将列表转换为字典
result = {k: v for k, v in data}

View File

@@ -13,5 +13,3 @@ if __name__ == "__main__":
str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
print(str.find("["))

View File

@@ -1,17 +1,22 @@
from enum import Enum
from typing import List
class Test(Enum):
XXX = ("x", "1", True)
YYY =("Y", "2", False)
YYY = ("Y", "2", False)
ZZZ = ("Z", "3")
def __init__(self, code, v, flag= False):
def __init__(self, code, v, flag=False):
self.code = code
self.v = v
self.flag = flag
class Scene:
def __init__(self, code, name, describe, param_types:List=[], is_inner: bool = False):
def __init__(
self, code, name, describe, param_types: List = [], is_inner: bool = False
):
self.code = code
self.name = name
self.describe = describe
@@ -20,25 +25,68 @@ class Scene:
class ChatScene(Enum):
ChatWithDbExecute = Scene(
"chat_with_db_execute",
"Chat Data",
"Dialogue with your private data through natural language.",
["DB Select"],
)
ChatWithDbQA = Scene(
"chat_with_db_qa",
"Chat Meta Data",
"Have a Professional Conversation with Metadata.",
["DB Select"],
)
ChatExecution = Scene(
"chat_execution",
"Chat Plugin",
"Use tools through dialogue to accomplish your goals.",
["Plugin Select"],
)
ChatDefaultKnowledge = Scene(
"chat_default_knowledge",
"Chat Default Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
)
ChatNewKnowledge = Scene(
"chat_new_knowledge",
"Chat New Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Select"],
)
ChatUrlKnowledge = Scene(
"chat_url_knowledge",
"Chat URL",
"Dialogue through natural language and private documents and knowledge bases.",
["Url Input"],
)
InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
)
ChatWithDbExecute = Scene("chat_with_db_execute", "Chat Data", "Dialogue with your private data through natural language.", ["DB Select"])
ChatWithDbQA = Scene("chat_with_db_qa", "Chat Meta Data", "Have a Professional Conversation with Metadata.", ["DB Select"])
ChatExecution = Scene("chat_execution", "Chat Plugin", "Use tools through dialogue to accomplish your goals.", ["Plugin Select"])
ChatDefaultKnowledge = Scene("chat_default_knowledge", "Chat Default Knowledge", "Dialogue through natural language and private documents and knowledge bases.")
ChatNewKnowledge = Scene("chat_new_knowledge", "Chat New Knowledge", "Dialogue through natural language and private documents and knowledge bases.", ["Knowledge Select"])
ChatUrlKnowledge = Scene("chat_url_knowledge", "Chat URL", "Dialogue through natural language and private documents and knowledge bases.", ["Url Input"])
InnerChatDBSummary = Scene("inner_chat_db_summary", "DB Summary", "Db Summary.", True)
ChatNormal = Scene("chat_normal", "Chat Normal", "Native LLM large model AI dialogue.")
ChatDashboard = Scene("chat_dashboard", "Chat Dashboard", "Provide you with professional analysis reports through natural language.", ["DB Select"])
ChatKnowledge = Scene("chat_knowledge", "Chat Knowledge", "Dialogue through natural language and private documents and knowledge bases.", ["Knowledge Space Select"])
ChatNormal = Scene(
"chat_normal", "Chat Normal", "Native LLM large model AI dialogue."
)
ChatDashboard = Scene(
"chat_dashboard",
"Chat Dashboard",
"Provide you with professional analysis reports through natural language.",
["DB Select"],
)
ChatKnowledge = Scene(
"chat_knowledge",
"Chat Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"],
)
def scene_value(self):
return self.value.code;
return self.value.code
def scene_name(self):
return self._value_.name;
return self._value_.name
if __name__ == "__main__":
print(ChatScene.ChatWithDbExecute.scene_value())
# print(ChatScene.ChatWithDbExecute.value.describe)
# print(ChatScene.ChatWithDbExecute.value.describe)

View File

@@ -6,7 +6,11 @@ from typing import List
import markdown
from bs4 import BeautifulSoup
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.text_splitter import (
SpacyTextSplitter,
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
)
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
@@ -44,7 +48,9 @@ class MarkdownEmbedding(SourceEmbedding):
chunk_overlap=100,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
)
return loader.load_and_split(text_splitter)
@register

View File

@@ -47,7 +47,9 @@ class PDFEmbedding(SourceEmbedding):
chunk_overlap=100,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
)
return loader.load_and_split(text_splitter)
@register

View File

@@ -4,7 +4,7 @@ from typing import List
from langchain.document_loaders import UnstructuredPowerPointLoader
from langchain.schema import Document
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from pilot.configs.config import Config
from pilot.embedding_engine import SourceEmbedding, register
@@ -45,7 +45,9 @@ class PPTEmbedding(SourceEmbedding):
chunk_overlap=100,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
)
return loader.load_and_split(text_splitter)
@register

View File

@@ -40,7 +40,9 @@ class URLEmbedding(SourceEmbedding):
chunk_overlap=100,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
)
return loader.load_and_split(text_splitter)
@register

View File

@@ -39,7 +39,9 @@ class WordEmbedding(SourceEmbedding):
chunk_overlap=100,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=50
)
return loader.load_and_split(text_splitter)
@register

View File

@@ -116,7 +116,8 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor = duckdb.connect(duckdb_path).cursor()
if user_name:
cursor.execute(
"SELECT * FROM chat_history where user_name=? order by id desc limit 20", [user_name]
"SELECT * FROM chat_history where user_name=? order by id desc limit 20",
[user_name],
)
else:
cursor.execute("SELECT * FROM chat_history order by id desc limit 20")

View File

@@ -11,7 +11,7 @@ def generate_stream(
"""Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py"""
prompt = params["prompt"]
l_prompt = len(prompt)
prompt= prompt.replace("ai:", "assistant:").replace("human:", "user:")
prompt = prompt.replace("ai:", "assistant:").replace("human:", "user:")
temperature = float(params.get("temperature", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
stop_str = params.get("stop", None)

View File

@@ -53,7 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
message = ""
for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(code= "E0001", msg=message)
return Result.faild(code="E0001", msg=message)
def __get_conv_user_message(conversations: dict):
@@ -95,8 +95,9 @@ def knowledge_list():
params.update({space.name: space.name})
return params
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list( user_id: str = None):
async def dialogue_list(user_id: str = None):
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
@@ -126,11 +127,10 @@ async def dialogue_scenes():
ChatScene.ChatExecution,
]
for scene in new_modes:
scene_vo = ChatSceneVo(
chat_scene=scene.value(),
scene_name=scene.scene_name(),
scene_describe= scene.describe(),
scene_describe=scene.describe(),
param_title=",".join(scene.param_types()),
)
scene_vos.append(scene_vo)
@@ -138,7 +138,9 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None):
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
):
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)

View File

@@ -130,7 +130,6 @@ class BaseOutputParser(ABC):
return temp_json
def __extract_json(self, s):
temp_json = self.__json_interception(s, True)
if not temp_json:
temp_json = self.__json_interception(s)
@@ -143,10 +142,10 @@ class BaseOutputParser(ABC):
def __json_interception(self, s, is_json_array: bool = False):
if is_json_array:
i = s.find("[")
if i <0:
if i < 0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "]":
count -= 1
elif c == "[":
@@ -154,13 +153,13 @@ class BaseOutputParser(ABC):
if count == 0:
break
assert count == 0
return s[i: j + 1]
return s[i : j + 1]
else:
i = s.find("{")
if i <0:
if i < 0:
return None
count = 1
for j, c in enumerate(s[i + 1:], start=i + 1):
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
@@ -168,7 +167,7 @@ class BaseOutputParser(ABC):
if count == 0:
break
assert count == 0
return s[i: j + 1]
return s[i : j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@@ -185,9 +184,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json"):]
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```"):]
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@@ -196,9 +195,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output

View File

@@ -3,7 +3,9 @@ from typing import List
class Scene:
def __init__(self, code, name, describe, param_types: List = [], is_inner: bool = False):
def __init__(
self, code, name, describe, param_types: List = [], is_inner: bool = False
):
self.code = code
self.name = name
self.describe = describe
@@ -12,41 +14,73 @@ class Scene:
class ChatScene(Enum):
ChatWithDbExecute = Scene("chat_with_db_execute", "Chat Data",
"Dialogue with your private data through natural language.", ["DB Select"])
ChatWithDbQA = Scene("chat_with_db_qa", "Chat Meta Data", "Have a Professional Conversation with Metadata.",
["DB Select"])
ChatExecution = Scene("chat_execution", "Plugin", "Use tools through dialogue to accomplish your goals.",
["Plugin Select"])
ChatDefaultKnowledge = Scene("chat_default_knowledge", "Chat Default Knowledge",
"Dialogue through natural language and private documents and knowledge bases.")
ChatNewKnowledge = Scene("chat_new_knowledge", "Chat New Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Select"])
ChatUrlKnowledge = Scene("chat_url_knowledge", "Chat URL",
"Dialogue through natural language and private documents and knowledge bases.",
["Url Input"])
InnerChatDBSummary = Scene("inner_chat_db_summary", "DB Summary", "Db Summary.", True)
ChatWithDbExecute = Scene(
"chat_with_db_execute",
"Chat Data",
"Dialogue with your private data through natural language.",
["DB Select"],
)
ChatWithDbQA = Scene(
"chat_with_db_qa",
"Chat Meta Data",
"Have a Professional Conversation with Metadata.",
["DB Select"],
)
ChatExecution = Scene(
"chat_execution",
"Plugin",
"Use tools through dialogue to accomplish your goals.",
["Plugin Select"],
)
ChatDefaultKnowledge = Scene(
"chat_default_knowledge",
"Chat Default Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
)
ChatNewKnowledge = Scene(
"chat_new_knowledge",
"Chat New Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Select"],
)
ChatUrlKnowledge = Scene(
"chat_url_knowledge",
"Chat URL",
"Dialogue through natural language and private documents and knowledge bases.",
["Url Input"],
)
InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
)
ChatNormal = Scene("chat_normal", "Chat Normal", "Native LLM large model AI dialogue.")
ChatDashboard = Scene("chat_dashboard", "Dashboard",
"Provide you with professional analysis reports through natural language.", ["DB Select"])
ChatKnowledge = Scene("chat_knowledge", "Chat Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"])
ChatNormal = Scene(
"chat_normal", "Chat Normal", "Native LLM large model AI dialogue."
)
ChatDashboard = Scene(
"chat_dashboard",
"Dashboard",
"Provide you with professional analysis reports through natural language.",
["DB Select"],
)
ChatKnowledge = Scene(
"chat_knowledge",
"Chat Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"],
)
@staticmethod
def is_valid_mode(mode):
return any(mode == item.value() for item in ChatScene)
def value(self):
return self._value_.code;
return self._value_.code
def scene_name(self):
return self._value_.name;
return self._value_.name
def describe(self):
return self._value_.describe;
return self._value_.describe
def param_types(self):
return self._value_.param_types

View File

@@ -126,7 +126,6 @@ class BaseChat(ABC):
# TODO Retry when server connection error
payload = self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
@@ -142,6 +141,7 @@ class BaseChat(ABC):
return response
else:
from pilot.server.llmserver import worker
return worker.generate_stream_gate(payload)
except Exception as e:
print(traceback.format_exc())
@@ -169,6 +169,7 @@ class BaseChat(ABC):
else:
###TODO no stream mode need independent
from pilot.server.llmserver import worker
output = worker.generate_stream_gate(payload)
for rsp in output:
rsp = rsp.replace(b"\0", b"")

View File

@@ -35,9 +35,7 @@ class ChatDashboard(BaseChat):
current_user_input=user_input,
)
if not db_name:
raise ValueError(
f"{ChatScene.ChatDashboard.value} mode should chose db!"
)
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should chose db!")
self.db_name = db_name
self.report_name = report_name
self.database = CFG.local_db
@@ -47,12 +45,11 @@ class ChatDashboard(BaseChat):
self.dashboard_template = self.__load_dashboard_template(report_name)
def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()
print(current_dir)
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(f"{current_dir}/template/{template_name}/dashboard.json", 'r') as f:
with open(f"{current_dir}/template/{template_name}/dashboard.json", "r") as f:
data = f.read()
return json.loads(data)
@@ -66,7 +63,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(self.db_connect),
"supported_chat_type": self.dashboard_template['supported_chart_type']
"supported_chat_type": self.dashboard_template["supported_chart_type"]
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
@@ -78,16 +75,24 @@ class ChatDashboard(BaseChat):
for chart_item in prompt_response:
try:
datas = self.database.run(self.db_connect, chart_item.sql)
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=datas[0],
values=datas))
chart_datas.append(
ChartData(
chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=datas[0],
values=datas,
)
)
except Exception as e:
# TODO 修复流程
print(str(e))
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
charts=chart_datas)
return ReportData(
conv_uid=self.chat_session_id,
template_name=self.report_name,
template_introduce=None,
charts=chart_datas,
)

View File

@@ -3,6 +3,7 @@ from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ChartData(BaseModel):
chart_uid: str
chart_name: str
@@ -21,10 +22,11 @@ class ChartData(BaseModel):
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [[str(item) for item in sublist] for sublist in self.values],
"style": self.style
"values": [[str(item) for item in sublist] for sublist in self.values],
"style": self.style,
}
class ReportData(BaseModel):
conv_uid: str
template_name: str
@@ -36,5 +38,5 @@ class ReportData(BaseModel):
"conv_uid": self.conv_uid,
"template_name": self.template_name,
"template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts]
}
"charts": [chart.dict() for chart in self.charts],
}

View File

@@ -9,6 +9,7 @@ from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene
class ChartItem(NamedTuple):
sql: str
title: str

View File

@@ -14,8 +14,8 @@ EXAMPLES = [
\"sql\": \"SELECT city FROM user where user_name='test1'\",
}""",
"example": True,
}
}
},
},
]
},
{
@@ -29,10 +29,10 @@ EXAMPLES = [
\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
"example": True,
}
}
},
},
]
}
},
]
sql_data_example = ExampleSelector(

View File

@@ -28,10 +28,10 @@ class DbChatOutputParser(BaseOutputParser):
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
for key in sorted(response):
if key.strip() == 'sql':
sql =response[key]
if key.strip() == 'thoughts':
thoughts =response[key]
if key.strip() == "sql":
sql = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str:

View File

@@ -51,6 +51,6 @@ prompt = PromptTemplate(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -14,8 +14,8 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
}
}
},
},
]
},
{
@@ -30,10 +30,10 @@ EXAMPLES = [
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
}
}
},
},
]
}
},
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -29,12 +29,12 @@ class PluginChatOutputParser(BaseOutputParser):
raise ValueError("model server out not fllow the prompt!")
for key in sorted(response):
if key.strip() == 'command':
command =response[key]
if key.strip() == 'thoughts':
thoughts =response[key]
if key.strip() == 'speak':
speak =response[key]
if key.strip() == "command":
command = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "speak":
speak = response[key]
return PluginAction(command, speak, thoughts)
def parse_view_response(self, speak, data) -> str:

View File

@@ -32,7 +32,6 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_h
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -66,7 +65,7 @@ app.add_middleware(
)
app.include_router(api_v1, prefix="/api")
app.include_router(api_v1, prefix="/api")
app.include_router(knowledge_router, prefix="/api")
app.include_router(api_v1)
@@ -77,7 +76,6 @@ app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static"
# app.mount("/chat", StaticFiles(directory=static_file_path + "/chat.html", html=True), name="chat")
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
@@ -91,7 +89,13 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument("-light", "--light", default=False,action="store_true", help="enable light mode")
parser.add_argument(
"-light",
"--light",
default=False,
action="store_true",
help="enable light mode",
)
signal.signal(signal.SIGINT, signal_handler)
# init server config
@@ -101,10 +105,12 @@ if __name__ == "__main__":
if not args.light:
print("Model Unified Deployment Mode!")
from pilot.server.llmserver import worker
worker.start_check()
CFG.NEW_SERVER_MODE = True
else:
CFG.SERVER_LIGHT_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@@ -56,9 +56,11 @@ def space_list(request: KnowledgeSpaceRequest):
def document_add(space_name: str, request: KnowledgeDocumentRequest):
print(f"/document/add params: {space_name}, {request}")
try:
return Result.succ(knowledge_space_service.create_knowledge_document(
space=space_name, request=request
))
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
)
# return Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document add error {e}")
@@ -106,9 +108,11 @@ async def document_upload(
KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
),
)
return Result.succ(knowledge_space_service.create_knowledge_document(
space=space_name, request=request
))
return Result.succ(
knowledge_space_service.create_knowledge_document(
space=space_name, request=request
)
)
# return Result.succ([])
return Result.faild(code="E000X", msg=f"doc_file is None")
except Exception as e:

View File

@@ -27,7 +27,8 @@ from enum import Enum
from pilot.server.knowledge.request.response import (
ChunkQueryResponse,
DocumentQueryResponse, SpaceQueryResponse,
DocumentQueryResponse,
SpaceQueryResponse,
)
knowledge_space_dao = KnowledgeSpaceDao()

View File

@@ -168,7 +168,7 @@ async def api_generate_stream(request: Request):
@app.post("/generate")
def generate(prompt_request: PromptRequest)->str:
def generate(prompt_request: PromptRequest) -> str:
params = {
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
@@ -176,7 +176,6 @@ def generate(prompt_request: PromptRequest)->str:
"stop": prompt_request.stop,
}
rsp_str = ""
output = worker.generate_stream_gate(params)
for rsp in output: