feat(model): API support reasoning (#2409)
@@ -245,7 +245,7 @@ def _run_flow_cmd_local(
             if not out.success:
                 cl.error(out.text)
             else:
-                cl.print(out.text, end="")
+                cl.print(out.gen_text_with_thinking(), end="")
     except Exception as e:
         cl.error(f"Failed to run flow: {e}", exit_code=1)
     finally:
@@ -436,7 +436,7 @@ def _run_flow_chat_local(
                 cl.error(f"Error: {out.text}")
                 raise Exception(out.text)
             else:
-                yield out.text
+                yield out.gen_text_with_thinking()

     async def _call(_call_body: Dict[str, Any]):
         nonlocal dag, dag_metadata
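Note: both local CLI paths above switch from printing or yielding the raw out.text to out.gen_text_with_thinking(). The body of that method is not part of this diff; the following is only a hedged sketch of what such a helper could look like (the FlowOutput class and its thinking field are hypothetical stand-ins, not DB-GPT's actual types), assuming it merges the model's reasoning with the final answer before display:

# Hypothetical sketch only -- not the actual DB-GPT implementation.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FlowOutput:
    text: str
    thinking: Optional[str] = None  # assumed field carrying reasoning text

    def gen_text_with_thinking(self) -> str:
        """Return the answer text, preceded by the quoted reasoning when present."""
        if not self.thinking:
            return self.text
        quoted = "\n".join(f"> {line}" for line in self.thinking.split("\n"))
        return quoted + "\n\n" + self.text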
@@ -466,6 +466,9 @@ def _run_flow_chat_stream(
         async for out in client.chat_stream(**_call_body):
             if out.choices:
                 text = out.choices[0].delta.content
+                reasoning_content = out.choices[0].delta.reasoning_content
+                if reasoning_content:
+                    yield reasoning_content
                 if text:
                     yield text

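Note: the streaming path now yields reasoning_content deltas before the regular content deltas, so the model's thinking arrives interleaved in the same text stream. A rough consumer-side sketch (the collect helper below is illustrative, not part of the client API) that keeps the two apart:

# Illustrative consumer, not part of the DB-GPT client API. It assumes
# OpenAI-style stream chunks like those iterated above, where each chunk
# carries choices[0].delta.content and choices[0].delta.reasoning_content.
async def collect(stream):
    reasoning_parts, answer_parts = [], []
    async for chunk in stream:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if getattr(delta, "reasoning_content", None):
            reasoning_parts.append(delta.reasoning_content)
        if delta.content:
            answer_parts.append(delta.content)
    return "".join(reasoning_parts), "".join(answer_parts)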
@@ -482,6 +485,13 @@ def _run_flow_chat(
         res = await client.chat(**_call_body)
         if res.choices:
             text = res.choices[0].message.content
+            if res.choices[0].message.reasoning_content:
+                reasoning_content = res.choices[0].message.reasoning_content
+                # For each line, add '>' at the beginning
+                reasoning_content = "\n".join(
+                    [f"> {line}" for line in reasoning_content.split("\n")]
+                )
+                text = reasoning_content + "\n\n" + text
             return text

     loop.run_until_complete(_chat(_call, interactive, json_data))
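Note: in the non-streaming path the reasoning is rendered as a Markdown blockquote: every line gets a "> " prefix and the quoted block is placed above the answer, separated by a blank line. Running that formatting step on made-up sample strings shows the effect:

# Same quoting logic as above, applied to made-up sample strings.
reasoning_content = "First, inspect the table schema.\nThen pick the join keys."
text = "The query needs a JOIN on user_id."

quoted = "\n".join([f"> {line}" for line in reasoning_content.split("\n")])
print(quoted + "\n\n" + text)
# > First, inspect the table schema.
# > Then pick the join keys.
#
# The query needs a JOIN on user_id.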
@@ -7,32 +7,20 @@ from typing import Any, Dict, List, Optional, Union
 from fastapi import File, UploadFile

-from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
+from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
 from dbgpt.core.awel import CommonLLMHttpRequestBody
+from dbgpt.core.schema.api import APIChatCompletionRequest
 from dbgpt_ext.rag.chunk_manager import ChunkParameters


-class ChatCompletionRequestBody(BaseModel):
+class ChatCompletionRequestBody(APIChatCompletionRequest):
     """ChatCompletion LLM http request body."""

-    model: str = Field(
-        ..., description="The model name", examples=["gpt-3.5-turbo", "proxyllm"]
-    )
-    messages: Union[str, List[str]] = Field(
-        ..., description="User input messages", examples=["Hello", "How are you?"]
-    )
-    stream: bool = Field(default=True, description="Whether return stream")
-
-    temperature: Optional[float] = Field(
-        default=None,
-        description="What sampling temperature to use, between 0 and 2. Higher values "
-        "like 0.8 will make the output more random, "
-        "while lower values like 0.2 will "
-        "make it more focused and deterministic.",
-    )
     max_new_tokens: Optional[int] = Field(
         default=None,
         description="The maximum number of tokens that can be generated in the chat "
         "completion.",
         deprecated="'max_new_tokens' is deprecated. Use 'max_tokens' instead.",
     )
     conv_uid: Optional[str] = Field(
         default=None, description="The conversation id of the model inference"
@@ -65,6 +53,37 @@ class ChatCompletionRequestBody(BaseModel):
         default=True, description="response content whether to output vis label"
     )

+    @model_validator(mode="before")
+    @classmethod
+    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Pre fill the messages."""
+        if not isinstance(values, dict):
+            return values
+        max_tokens = values.get("max_tokens")
+        max_new_tokens = values.get("max_new_tokens")
+        if max_tokens is None and max_new_tokens is not None:
+            values["max_tokens"] = max_new_tokens
+        return values
+
+    def to_common_llm_http_request_body(self) -> CommonLLMHttpRequestBody:
+        """Convert to CommonLLMHttpRequestBody."""
+        max_new_tokens = self.max_tokens
+        return CommonLLMHttpRequestBody(
+            model=self.model,
+            messages=self.single_prompt(),
+            stream=self.stream,
+            temperature=self.temperature,
+            max_new_tokens=max_new_tokens,
+            conv_uid=self.conv_uid,
+            span_id=self.span_id,
+            chat_mode=self.chat_mode,
+            chat_param=self.chat_param,
+            user_name=self.user_name,
+            sys_code=self.sys_code,
+            incremental=self.incremental,
+            enable_vis=self.enable_vis,
+        )
+
+
 class ChatMode(Enum):
     """Chat mode."""

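Note: the pre_fill validator keeps the deprecated max_new_tokens parameter working: if a request sets only max_new_tokens, its value is copied into max_tokens before validation, and to_common_llm_http_request_body then passes that value back as max_new_tokens on the internal CommonLLMHttpRequestBody. A minimal standalone sketch of the same mapping with plain pydantic (the real class inherits max_tokens and the other fields from APIChatCompletionRequest; the Request model below exists only for illustration):

# Minimal standalone sketch of the backward-compatibility mapping; DB-GPT
# re-exports these pydantic helpers from dbgpt._private.pydantic.
from typing import Any, Dict, Optional

from pydantic import BaseModel, model_validator


class Request(BaseModel):
    max_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Copy the deprecated field onto max_tokens when only it is set.
        if isinstance(values, dict) and values.get("max_tokens") is None:
            if values.get("max_new_tokens") is not None:
                values["max_tokens"] = values["max_new_tokens"]
        return values


print(Request(max_new_tokens=256).max_tokens)                  # 256
print(Request(max_tokens=128, max_new_tokens=256).max_tokens)  # 128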
@@ -74,6 +93,7 @@ class ChatMode(Enum):
     CHAT_AWEL_FLOW = "chat_flow"
     CHAT_KNOWLEDGE = "chat_knowledge"
+    CHAT_DATA = "chat_data"
     CHAT_DB_QA = "chat_with_db_qa"


 class AWELTeamModel(BaseModel):