feat(model): API support reasoning (#2409)

Author: Fangyin Cheng
Date:   2025-03-07 15:31:12 +08:00 (committed by GitHub)
Parent: 2697aba4f5
Commit: 4e993a2be8

76 changed files with 653 additions and 212 deletions


@@ -245,7 +245,7 @@ def _run_flow_cmd_local(
         if not out.success:
             cl.error(out.text)
         else:
-            cl.print(out.text, end="")
+            cl.print(out.gen_text_with_thinking(), end="")
     except Exception as e:
         cl.error(f"Failed to run flow: {e}", exit_code=1)
     finally:
@@ -436,7 +436,7 @@ def _run_flow_chat_local(
                 cl.error(f"Error: {out.text}")
                 raise Exception(out.text)
             else:
-                yield out.text
+                yield out.gen_text_with_thinking()
 
     async def _call(_call_body: Dict[str, Any]):
         nonlocal dag, dag_metadata
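Note: both hunks above switch from the plain `out.text` to `out.gen_text_with_thinking()`. The helper itself is not shown in this diff; the following is a minimal sketch of what such a method plausibly does, assuming the output object carries an optional reasoning string next to `text` (the names `FlowOutput` and `thinking` are illustrative, not the actual dbgpt API):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FlowOutput:
        text: str
        thinking: Optional[str] = None  # reasoning emitted by the model, if any

        def gen_text_with_thinking(self) -> str:
            """Return the answer, prefixed by the quoted reasoning when present."""
            if not self.thinking:
                return self.text
            # Quote the reasoning so it is visually separated from the answer.
            quoted = "\n".join(f"> {line}" for line in self.thinking.split("\n"))
            return quoted + "\n\n" + self.text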
@@ -466,6 +466,9 @@ def _run_flow_chat_stream(
         async for out in client.chat_stream(**_call_body):
             if out.choices:
                 text = out.choices[0].delta.content
+                reasoning_content = out.choices[0].delta.reasoning_content
+                if reasoning_content:
+                    yield reasoning_content
                 if text:
                     yield text
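The streaming path reads a `reasoning_content` attribute off each delta, the field DeepSeek-style reasoning models attach to OpenAI-compatible stream chunks. A sketch of the delta shape this loop assumes (an illustrative dataclass, not the actual response model):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class DeltaMessage:
        content: Optional[str] = None            # normal answer tokens
        reasoning_content: Optional[str] = None  # "thinking" tokens, streamed first

Because reasoning deltas are yielded before content deltas, downstream consumers see the model think before it answers.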
@@ -482,6 +485,13 @@ def _run_flow_chat(
         res = await client.chat(**_call_body)
         if res.choices:
             text = res.choices[0].message.content
+            if res.choices[0].message.reasoning_content:
+                reasoning_content = res.choices[0].message.reasoning_content
+                # For each line, add '>' at the beginning
+                reasoning_content = "\n".join(
+                    [f"> {line}" for line in reasoning_content.split("\n")]
+                )
+                text = reasoning_content + "\n\n" + text
             return text
 
     loop.run_until_complete(_chat(_call, interactive, json_data))
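The non-streaming path renders the reasoning as a Markdown blockquote ahead of the answer. A small worked example of that transformation:

    reasoning_content = "First inspect the flow DAG.\nThen execute it."
    quoted = "\n".join([f"> {line}" for line in reasoning_content.split("\n")])
    print(quoted + "\n\n" + "The flow ran successfully.")
    # Output:
    # > First inspect the flow DAG.
    # > Then execute it.
    #
    # The flow ran successfully.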


@@ -7,32 +7,20 @@ from typing import Any, Dict, List, Optional, Union
 from fastapi import File, UploadFile
 
-from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
+from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
 from dbgpt.core.awel import CommonLLMHttpRequestBody
+from dbgpt.core.schema.api import APIChatCompletionRequest
 from dbgpt_ext.rag.chunk_manager import ChunkParameters
 
 
-class ChatCompletionRequestBody(BaseModel):
+class ChatCompletionRequestBody(APIChatCompletionRequest):
     """ChatCompletion LLM http request body."""
 
-    model: str = Field(
-        ..., description="The model name", examples=["gpt-3.5-turbo", "proxyllm"]
-    )
-    messages: Union[str, List[str]] = Field(
-        ..., description="User input messages", examples=["Hello", "How are you?"]
-    )
-    stream: bool = Field(default=True, description="Whether return stream")
-    temperature: Optional[float] = Field(
-        default=None,
-        description="What sampling temperature to use, between 0 and 2. Higher values "
-        "like 0.8 will make the output more random, "
-        "while lower values like 0.2 will "
-        "make it more focused and deterministic.",
-    )
-    max_new_tokens: Optional[int] = Field(
-        default=None,
-        description="The maximum number of tokens that can be generated in the chat "
-        "completion.",
-        deprecated="'max_new_tokens' is deprecated. Use 'max_tokens' instead.",
-    )
     conv_uid: Optional[str] = Field(
         default=None, description="The conversation id of the model inference"
@@ -65,6 +53,37 @@ class ChatCompletionRequestBody(BaseModel):
         default=True, description="response content whether to output vis label"
     )
+
+    @model_validator(mode="before")
+    @classmethod
+    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Pre fill the messages."""
+        if not isinstance(values, dict):
+            return values
+        max_tokens = values.get("max_tokens")
+        max_new_tokens = values.get("max_new_tokens")
+        if max_tokens is None and max_new_tokens is not None:
+            values["max_tokens"] = max_new_tokens
+        return values
+
+    def to_common_llm_http_request_body(self) -> CommonLLMHttpRequestBody:
+        """Convert to CommonLLMHttpRequestBody."""
+        max_new_tokens = self.max_tokens
+        return CommonLLMHttpRequestBody(
+            model=self.model,
+            messages=self.single_prompt(),
+            stream=self.stream,
+            temperature=self.temperature,
+            max_new_tokens=max_new_tokens,
+            conv_uid=self.conv_uid,
+            span_id=self.span_id,
+            chat_mode=self.chat_mode,
+            chat_param=self.chat_param,
+            user_name=self.user_name,
+            sys_code=self.sys_code,
+            incremental=self.incremental,
+            enable_vis=self.enable_vis,
+        )
 
 
 class ChatMode(Enum):
     """Chat mode."""
@@ -74,6 +93,7 @@ class ChatMode(Enum):
     CHAT_AWEL_FLOW = "chat_flow"
     CHAT_KNOWLEDGE = "chat_knowledge"
     CHAT_DATA = "chat_data"
+    CHAT_DB_QA = "chat_with_db_qa"
 
 
 class AWELTeamModel(BaseModel):
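The new enum member is selected by its string value in a request; an illustrative payload (field names follow `ChatCompletionRequestBody` above):

    payload = {
        "model": "proxyllm",
        "messages": "Which tables does my database contain?",
        "chat_mode": ChatMode.CHAT_DB_QA.value,  # "chat_with_db_qa"
    }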