mirror of https://github.com/csunny/DB-GPT.git

fix(model): Fix apiserver error (#2605)
@@ -1,6 +1,6 @@
 [project]
 name = "dbgpt"
-version = "0.7.0"
+version = "0.7.1"
 description = """DB-GPT is an experimental open-source project that uses localized GPT \
 large models to interact with your data and environment. With this solution, you can be\
 assured that there is no risk of data leakage, and your data is 100% private and secure.\
@@ -83,6 +83,7 @@ framework = [
     "gTTS==2.3.1",
     "pymysql",
     "jsonschema",
+    "python-jsonpath",
    # TODO move transformers to default
     "tokenizers>=0.14",
     "alembic==1.12.0",

@@ -1 +1 @@
-version = "0.7.0"
+version = "0.7.1"

@@ -249,6 +249,34 @@ class MediaContent:
         """Create a MediaContent object from thinking."""
         return cls(type="thinking", object=MediaObject(data=text, format="text"))
 
+    @classmethod
+    def parse_content(
+        cls,
+        content: Union[
+            "MediaContent", List["MediaContent"], Dict[str, Any], List[Dict[str, Any]]
+        ],
+    ) -> Union["MediaContent", List["MediaContent"]]:
+        def _parse_dict(obj_dict: Union[MediaContent, Dict[str, Any]]) -> MediaContent:
+            if isinstance(obj_dict, MediaContent):
+                return obj_dict
+            content_object = obj_dict.get("object")
+            if not content_object:
+                raise ValueError(f"Failed to parse {obj_dict}, no object found")
+            if isinstance(content_object, dict):
+                content_object = MediaObject(
+                    data=content_object.get("data"),
+                    format=content_object.get("format", "text"),
+                )
+            return cls(
+                type=obj_dict.get("type", "text"),
+                object=content_object,
+            )
+
+        if isinstance(content, list):
+            return [_parse_dict(c) for c in content]
+        else:
+            return _parse_dict(content)
+
     def get_text(self) -> str:
         """Get the text."""
         if self.type == MediaContentType.TEXT:
@@ -322,7 +350,11 @@ class ModelOutput:
         self,
         error_code: int,
         text: Optional[str] = None,
-        content: Optional[MediaContent] = None,
+        content: Optional[
+            Union[
+                MediaContent, List[MediaContent], Dict[str, Any], List[Dict[str, Any]]
+            ]
+        ] = None,
         **kwargs,
     ):
         if text is not None and content is not None:
@@ -330,7 +362,7 @@ class ModelOutput:
         elif text is not None:
             self.content = MediaContent.build_text(text)
         elif content is not None:
-            self.content = content
+            self.content = MediaContent.parse_content(content)
         else:
             raise ValueError("Must pass either text or content")
         self.error_code = error_code
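For illustration, a minimal sketch (not part of the commit) of what the widened signature accepts, assuming only the MediaContent/MediaObject API visible in the diff:

    # A plain dict in the shape the apiserver sends; parse_content normalizes it.
    raw = {"type": "text", "object": {"data": "Hello", "format": "text"}}

    # Previously a dict passed as `content` was stored unparsed and later broke
    # the apiserver; now it is converted into a MediaContent before assignment.
    output = ModelOutput(error_code=0, content=raw)

    # Lists of dicts are normalized element-wise:
    many = MediaContent.parse_content([raw, raw])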
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from datetime import datetime
 from typing import Callable, Dict, List, Optional, Tuple, Union, cast
@@ -14,6 +15,8 @@ from dbgpt.core.interface.storage import (
     StorageItem,
 )
 
+from ..schema.types import ChatCompletionMessageParam
+
 
 class BaseMessage(BaseModel, ABC):
     """Message object."""
@@ -179,9 +182,145 @@ class ModelMessage(BaseModel):
         )
         return result
 
+    @staticmethod
+    def _parse_openai_system_message(
+        message: ChatCompletionMessageParam,
+    ) -> List[ModelMessage]:
+        """Parse a system message from OpenAI format.
+
+        Args:
+            message (ChatCompletionMessageParam): The OpenAI message
+
+        Returns:
+            List[ModelMessage]: The model messages
+        """
+        content = message["content"]
+        result = []
+        if isinstance(content, str):
+            result.append(
+                ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
+            )
+        elif isinstance(content, Iterable):
+            for item in content:
+                if isinstance(item, str):
+                    result.append(
+                        ModelMessage(role=ModelMessageRoleType.SYSTEM, content=item)
+                    )
+                elif isinstance(item, dict) and "type" in item:
+                    type = item["type"]
+                    if type == "text" and "text" in item:
+                        result.append(
+                            ModelMessage(
+                                role=ModelMessageRoleType.SYSTEM, content=item["text"]
+                            )
+                        )
+                    else:
+                        raise ValueError(
+                            f"Unknown message type: {item} of system message"
+                        )
+                else:
+                    raise ValueError(f"Unknown message type: {item} of system message")
+        else:
+            raise ValueError(f"Unknown content type: {message} of system message")
+        return result
+
+    @staticmethod
+    def _parse_openai_user_message(
+        message: ChatCompletionMessageParam,
+    ) -> List[ModelMessage]:
+        """Parse a user message from OpenAI format.
+
+        Args:
+            message (ChatCompletionMessageParam): The OpenAI message
+
+        Returns:
+            List[ModelMessage]: The model messages
+        """
+        result = []
+        content = message["content"]
+        if isinstance(content, str):
+            result.append(
+                ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
+            )
+        elif isinstance(content, Iterable):
+            for item in content:
+                if isinstance(item, str):
+                    result.append(
+                        ModelMessage(role=ModelMessageRoleType.HUMAN, content=item)
+                    )
+                elif isinstance(item, dict) and "type" in item:
+                    type = item["type"]
+                    if type == "text" and "text" in item:
+                        result.append(
+                            ModelMessage(
+                                role=ModelMessageRoleType.HUMAN, content=item["text"]
+                            )
+                        )
+                    elif type == "image_url":
+                        raise ValueError("Image message is not supported now")
+                    elif type == "input_audio":
+                        raise ValueError("Input audio message is not supported now")
+                    else:
+                        raise ValueError(
+                            f"Unknown message type: {item} of human message"
+                        )
+                else:
+                    raise ValueError(f"Unknown message type: {item} of human message")
+        else:
+            raise ValueError(f"Unknown content type: {message} of human message")
+        return result
+
+    @staticmethod
+    def _parse_assistant_message(
+        message: ChatCompletionMessageParam,
+    ) -> List[ModelMessage]:
+        """Parse an assistant message from OpenAI format.
+
+        Args:
+            message (ChatCompletionMessageParam): The OpenAI message
+
+        Returns:
+            List[ModelMessage]: The model messages
+        """
+        result = []
+        content = message["content"]
+        if isinstance(content, str):
+            result.append(ModelMessage(role=ModelMessageRoleType.AI, content=content))
+        elif isinstance(content, Iterable):
+            for item in content:
+                if isinstance(item, str):
+                    result.append(
+                        ModelMessage(role=ModelMessageRoleType.AI, content=item)
+                    )
+                elif isinstance(item, dict) and "type" in item:
+                    type = item["type"]
+                    if type == "text" and "text" in item:
+                        result.append(
+                            ModelMessage(
+                                role=ModelMessageRoleType.AI, content=item["text"]
+                            )
+                        )
+                    elif type == "refusal" and "refusal" in item:
+                        result.append(
+                            ModelMessage(
+                                role=ModelMessageRoleType.AI, content=item["refusal"]
+                            )
+                        )
+                    else:
+                        raise ValueError(
+                            f"Unknown message type: {item} of assistant message"
+                        )
+                else:
+                    raise ValueError(
+                        f"Unknown message type: {item} of assistant message"
+                    )
+        else:
+            raise ValueError(f"Unknown content type: {message} of assistant message")
+        return result
+
     @staticmethod
     def from_openai_messages(
-        messages: Union[str, List[Dict[str, str]]],
+        messages: Union[str, List[ChatCompletionMessageParam]],
     ) -> List["ModelMessage"]:
         """Openai message format to current ModelMessage format."""
         if isinstance(messages, str):
@@ -189,19 +328,18 @@ class ModelMessage(BaseModel):
         result = []
         for message in messages:
             msg_role = message["role"]
-            content = message["content"]
             if msg_role == "system":
-                result.append(
-                    ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
-                )
+                result.extend(ModelMessage._parse_openai_system_message(message))
             elif msg_role == "user":
-                result.append(
-                    ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
-                )
+                result.extend(ModelMessage._parse_openai_user_message(message))
             elif msg_role == "assistant":
-                result.append(
-                    ModelMessage(role=ModelMessageRoleType.AI, content=content)
-                )
+                result.extend(ModelMessage._parse_assistant_message(message))
+            elif msg_role == "function":
+                raise ValueError(
+                    "Function role is not supported in ModelMessage format"
+                )
+            elif msg_role == "tool":
+                raise ValueError("Tool role is not supported in ModelMessage format")
             else:
                 raise ValueError(f"Unknown role: {msg_role}")
         return result
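A usage sketch (not part of the commit) of the reworked parser; the message payloads follow the OpenAI shapes handled above:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
        {"role": "assistant", "content": [{"type": "refusal", "refusal": "I cannot."}]},
    ]
    # Produces SYSTEM/HUMAN/AI ModelMessage objects in order. "image_url" and
    # "input_audio" parts, and the "function"/"tool" roles, raise ValueError for now.
    model_messages = ModelMessage.from_openai_messages(messages)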
@@ -7,6 +7,8 @@ from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
 
 from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
 
+from .types import ChatCompletionMessageParam
+
 T = TypeVar("T")
 
 
@@ -47,14 +49,13 @@ class Result(BaseModel, Generic[T]):
         return model_to_dict(self, **kwargs)
 
 
-_ChatCompletionMessageType = Union[str, List[Dict[str, str]], List[str]]
-
-
 class APIChatCompletionRequest(BaseModel):
     """Chat completion request entity."""
 
     model: str = Field(..., description="Model name")
-    messages: _ChatCompletionMessageType = Field(..., description="User input messages")
+    messages: Union[str, List[ChatCompletionMessageParam]] = Field(
+        ..., description="User input messages"
+    )
     temperature: Optional[float] = Field(
         0.7,
         description="What sampling temperature to use, between 0 and 2. Higher values "
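For illustration, a request body that now validates against this schema (the model name is a placeholder, not from the diff):

    request = APIChatCompletionRequest(
        model="my-model",  # placeholder model name
        messages=[
            {"role": "user", "content": [{"type": "text", "text": "What is DB-GPT?"}]}
        ],
    )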
packages/dbgpt-core/src/dbgpt/core/schema/types.py (new file, 214 lines)
@@ -0,0 +1,214 @@
+"""Adapted from OpenAI API types.
+
+All types are adapted from the OpenAI API types. They are used to provide
+OpenAI-compatible types for the DB-GPT apiserver.
+
+Note: these are not the internal types of DB-GPT.
+"""
+
+from typing import Iterable, Optional, TypeAlias, Union
+
+from typing_extensions import Literal, Required, TypedDict
+
+
+class ChatCompletionContentPartTextParam(TypedDict, total=False):
+    text: Required[str]
+    """The text content."""
+
+    type: Required[Literal["text"]]
+    """The type of the content part."""
+
+
+class ImageURL(TypedDict, total=False):
+    url: Required[str]
+    """Either a URL of the image or the base64 encoded image data."""
+
+    detail: Literal["auto", "low", "high"]
+    """Specifies the detail level of the image.
+
+    Learn more in the
+    [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding).
+    """
+
+
+class ChatCompletionContentPartImageParam(TypedDict, total=False):
+    image_url: Required[ImageURL]
+
+    type: Required[Literal["image_url"]]
+    """The type of the content part."""
+
+
+class InputAudio(TypedDict, total=False):
+    data: Required[str]
+    """Base64 encoded audio data."""
+
+    format: Required[Literal["wav", "mp3"]]
+    """The format of the encoded audio data. Currently supports "wav" and "mp3"."""
+
+
+class ChatCompletionContentPartInputAudioParam(TypedDict, total=False):
+    input_audio: Required[InputAudio]
+
+    type: Required[Literal["input_audio"]]
+    """The type of the content part. Always `input_audio`."""
+
+
+class Function(TypedDict, total=False):
+    arguments: Required[str]
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: Required[str]
+    """The name of the function to call."""
+
+
+class FunctionCall(TypedDict, total=False):
+    arguments: Required[str]
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+    name: Required[str]
+    """The name of the function to call."""
+
+
+class ChatCompletionMessageToolCallParam(TypedDict, total=False):
+    id: Required[str]
+    """The ID of the tool call."""
+
+    function: Required[Function]
+    """The function that the model called."""
+
+    type: Required[Literal["function"]]
+    """The type of the tool. Currently, only `function` is supported."""
+
+
+class Audio(TypedDict, total=False):
+    id: Required[str]
+    """Unique identifier for a previous audio response from the model."""
+
+
+class ChatCompletionContentPartRefusalParam(TypedDict, total=False):
+    refusal: Required[str]
+    """The refusal message generated by the model."""
+
+    type: Required[Literal["refusal"]]
+    """The type of the content part."""
+
+
+ChatCompletionContentPartParam: TypeAlias = Union[
+    ChatCompletionContentPartTextParam,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
+]
+ContentArrayOfContentPart: TypeAlias = Union[
+    ChatCompletionContentPartTextParam, ChatCompletionContentPartRefusalParam
+]
+
+
+class ChatCompletionSystemMessageParam(TypedDict, total=False):
+    content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]]
+    """The contents of the system message."""
+
+    role: Required[Literal["system"]]
+    """The role of the messages author, in this case `system`."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+
+class ChatCompletionUserMessageParam(TypedDict, total=False):
+    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
+    """The contents of the user message."""
+
+    role: Required[Literal["user"]]
+    """The role of the messages author, in this case `user`."""
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+
+class ChatCompletionAssistantMessageParam(TypedDict, total=False):
+    role: Required[Literal["assistant"]]
+    """The role of the messages author, in this case `assistant`."""
+
+    audio: Optional[Audio]
+    """Data about a previous audio response from the model.
+
+    [Learn more](https://platform.openai.com/docs/guides/audio).
+    """
+
+    content: Union[str, Iterable[ContentArrayOfContentPart], None]
+    """The contents of the assistant message.
+
+    Required unless `tool_calls` or `function_call` is specified.
+    """
+
+    function_call: Optional[FunctionCall]
+    """Deprecated and replaced by `tool_calls`.
+
+    The name and arguments of a function that should be called, as generated by the
+    model.
+    """
+
+    name: str
+    """An optional name for the participant.
+
+    Provides the model information to differentiate between participants of the same
+    role.
+    """
+
+    refusal: Optional[str]
+    """The refusal message by the assistant."""
+
+    tool_calls: Iterable[ChatCompletionMessageToolCallParam]
+    """The tool calls generated by the model, such as function calls."""
+
+
+class ChatCompletionToolMessageParam(TypedDict, total=False):
+    content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]]
+    """The contents of the tool message."""
+
+    role: Required[Literal["tool"]]
+    """The role of the messages author, in this case `tool`."""
+
+    tool_call_id: Required[str]
+    """Tool call that this message is responding to."""
+
+
+class ChatCompletionFunctionMessageParam(TypedDict, total=False):
+    content: Required[Optional[str]]
+    """The contents of the function message."""
+
+    name: Required[str]
+    """The name of the function to call."""
+
+    role: Required[Literal["function"]]
+    """The role of the messages author, in this case `function`."""
+
+
+# from openai.types.chat import ChatCompletionMessageParam
+OpenAIChatCompletionMessageParam: TypeAlias = Union[
+    ChatCompletionSystemMessageParam,
+    ChatCompletionUserMessageParam,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionFunctionMessageParam,
+]
+
+ChatCompletionMessageParam = OpenAIChatCompletionMessageParam
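A brief sketch of how these TypedDicts are meant to be used; they are plain dicts at runtime, and the annotations only drive static checking:

    from typing import List

    from dbgpt.core.schema.types import (
        ChatCompletionMessageParam,
        ChatCompletionUserMessageParam,
    )

    msg: ChatCompletionUserMessageParam = {
        "role": "user",
        "content": [{"type": "text", "text": "Hello"}],
    }
    messages: List[ChatCompletionMessageParam] = [msg]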
@@ -5,6 +5,7 @@ Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_
 """
 
 import asyncio
 import json
+import logging
 import os
 from typing import Any, Dict, Generator, List, Optional
@@ -50,9 +51,13 @@ from dbgpt.model.cluster.registry import ModelRegistry
 from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType
 from dbgpt.util.chat_util import transform_to_sse
 from dbgpt.util.fastapi import create_app
-from dbgpt.util.tracer import initialize_tracer, root_tracer
+from dbgpt.util.tracer import initialize_tracer, root_tracer, trace
 from dbgpt.util.tracer.tracer_impl import TracerParameters
-from dbgpt.util.utils import LoggingParameters, setup_logging
+from dbgpt.util.utils import (
+    LoggingParameters,
+    logging_str_to_uvicorn_level,
+    setup_logging,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -319,28 +324,49 @@ class APIServer(BaseComponent):
                 )
                 yield transform_to_sse(chunk)
 
         delta_text = ""
         previous_text = ""
+        thinking_text = ""
+        previous_thinking_text = ""
+        full_text = ""
+
+        span = root_tracer.start_span(
+            "API.chat_completion_stream_generator",
+            metadata={
+                "model": model_name,
+                "params": json.dumps(params, ensure_ascii=False),
+            },
+        )
+
         async for model_output in worker_manager.generate_stream(params):
             model_output: ModelOutput = model_output
             if model_output.error_code != 0:
                 yield transform_to_sse(model_output.to_dict())
                 yield transform_to_sse("[DONE]")
                 return
-            decoded_unicode = model_output.text.replace("\ufffd", "")
-            delta_text = decoded_unicode[len(previous_text) :]
-            previous_text = (
-                decoded_unicode
-                if len(decoded_unicode) > len(previous_text)
-                else previous_text
-            )
+            if model_output.has_text:
+                full_text = model_output.text
+                decoded_unicode = model_output.text.replace("\ufffd", "")
+                delta_text = decoded_unicode[len(previous_text) :]
+                previous_text = (
+                    decoded_unicode
+                    if len(decoded_unicode) > len(previous_text)
+                    else previous_text
+                )
+            if model_output.has_thinking:
+                decoded_unicode = model_output.thinking_text.replace("\ufffd", "")
+                thinking_text = decoded_unicode[len(previous_thinking_text) :]
+                previous_thinking_text = (
+                    decoded_unicode
+                    if len(decoded_unicode) > len(previous_thinking_text)
+                    else previous_thinking_text
+                )
 
-            if len(delta_text) == 0:
+            if not delta_text:
                 delta_text = None
-            choice_data = ChatCompletionResponseStreamChoice(
-                index=i,
-                delta=DeltaMessage(content=delta_text),
-                finish_reason=model_output.finish_reason,
-            )
+            if not thinking_text:
+                thinking_text = None
 
             has_usage = False
             if model_output.usage:
                 curr_usage = UsageInfo.model_validate(model_output.usage)
@@ -353,17 +379,29 @@ class APIServer(BaseComponent):
                     + curr_usage.completion_tokens,
                 )
             else:
                 has_usage = False
                 usage = UsageInfo()
-            chunk = ChatCompletionStreamResponse(
-                id=id, choices=[choice_data], model=model_name, usage=usage
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(
+                    content=delta_text, reasoning_content=thinking_text
+                ),
+                finish_reason=model_output.finish_reason,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=id, choices=[choice_data], model=model_name or "", usage=usage
             )
-            if delta_text is None:
+            if delta_text is None and thinking_text is None:
                 if model_output.finish_reason is not None:
                     finish_stream_events.append(chunk)
                 if not has_usage:
                     continue
 
             yield transform_to_sse(chunk)
+        span.end(
+            metadata={
+                "full_text": full_text,
+            }
+        )
 
         # There is no "content" field in the last delta message, so use
         # exclude_none to exclude the "content" field.
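For illustration, a standalone sketch of the cumulative-to-delta logic the generator now applies separately to answer text and thinking text (plain strings, no server state):

    from typing import Tuple

    def next_delta(cumulative: str, previous: str) -> Tuple[str, str]:
        """Turn a cumulative stream snapshot into the newly generated suffix."""
        decoded = cumulative.replace("\ufffd", "")  # strip partial-decode chars
        delta = decoded[len(previous):]
        previous = decoded if len(decoded) > len(previous) else previous
        return delta, previous

    previous = ""
    for snapshot in ["He", "Hell", "Hello"]:
        delta, previous = next_delta(snapshot, previous)
        # deltas: "He", "ll", "o" -- emitted as DeltaMessage.content; the same
        # computation now also runs over thinking_text to fill reasoning_content.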
@@ -371,6 +409,7 @@ class APIServer(BaseComponent):
             yield transform_to_sse(finish_chunk)
         yield transform_to_sse("[DONE]")
 
+    @trace()
     async def chat_completion_generate(
         self, model_name: str, params: Dict[str, Any], n: int
     ) -> ChatCompletionResponse:
@@ -398,7 +437,11 @@ class APIServer(BaseComponent):
             choices.append(
                 ChatCompletionResponseChoice(
                     index=i,
-                    message=ChatMessage(role="assistant", content=model_output.text),
+                    message=ChatMessage(
+                        role="assistant",
+                        content=model_output.text,
+                        reasoning_content=model_output.thinking_text,
+                    ),
                     finish_reason=model_output.finish_reason or "stop",
                 )
             )
@@ -837,7 +880,12 @@ def initialize_apiserver(
 
     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(request, exc):
-        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
+        message = ""
+        for error in exc.errors():
+            loc = ".".join(list(map(str, error.get("loc"))))
+            message += loc + ":" + error.get("msg") + ";"
+        logger.warning(message)
+        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, message)
 
     _initialize_all(apiserver_params.controller_addr, system_app)
 
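For illustration, the message shape the new handler builds from exc.errors() (the error entry below is hypothetical):

    errors = [{"loc": ("body", "messages"), "msg": "field required"}]
    message = ""
    for error in errors:
        loc = ".".join(list(map(str, error.get("loc"))))
        message += loc + ":" + error.get("msg") + ";"
    print(message)  # -> "body.messages:field required;"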
@@ -852,11 +900,14 @@ def initialize_apiserver(
         allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
         allow_headers=["*"],
     )
+    log_level = "info"
+    if log_config:
+        log_level = logging_str_to_uvicorn_level(log_config.level)
     uvicorn.run(
         cors_app,
         host=apiserver_params.host,
         port=apiserver_params.port,
-        log_level="info",
+        log_level=log_level,
     )

@@ -94,6 +94,44 @@ def trace_cli_group():
     default="text",
     help="The output format",
 )
+@click.option(
+    "-j",
+    "--json_path",
+    required=False,
+    type=str,
+    default=None,
+    help=(
+        "Extract specific JSON path from spans using JSONPath syntax. Example: "
+        "'$.metadata.messages[0].content'"
+    ),
+)
+@click.option(
+    "-sj",
+    "search_json_path",
+    required=False,
+    type=str,
+    default=None,
+    help=(
+        "JSONPath to evaluate when filtering spans; used together with "
+        "--json_path_match. Example: '$.metadata.messages[0].content'"
+    ),
+)
+@click.option(
+    "-jm",
+    "--json_path_match",
+    required=False,
+    type=str,
+    default=None,
+    help="Filter the extracted spans to those whose search JSON path yields this value",
+)
+@click.option(
+    "--value",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Just show the value after extracting the JSON path",
+)
 @click.argument("files", nargs=-1, type=click.Path(exists=True, readable=True))
 def list(
     trace_id: str,
@@ -106,6 +144,10 @@ def list(
     end_time: str,
     desc: bool,
     output: str,
+    json_path: str,
+    search_json_path: str,
+    json_path_match: str,
+    value: bool = False,
     files=None,
 ):
     """List your trace spans"""
@@ -141,8 +183,84 @@ def list(
     # Sort spans based on the start time
     spans = sorted(
         spans, key=lambda span: _parse_datetime(span["start_time"]), reverse=desc
-    )[:limit]
+    )
+
+    # Handle JSON path extraction if specified
+    if json_path:
+        try:
+            # Try to import python-jsonpath
+            try:
+                import jsonpath
+            except ImportError:
+                print("'python-jsonpath' library is required for --json_path option.")
+                print("Please install it with: pip install python-jsonpath")
+                return
+
+            # Add root prefix $ if not present
+            if not json_path.startswith("$"):
+                json_path = "$." + json_path
+            if search_json_path and not search_json_path.startswith("$"):
+                search_json_path = "$." + search_json_path
+
+            # Process all spans
+            extracted_data = []
+
+            for span in spans:
+                try:
+                    # Find all matches using python-jsonpath
+                    results = jsonpath.findall(json_path, span)
+                    if results and search_json_path and json_path_match:
+                        search_results = jsonpath.findall(search_json_path, span)
+                        if not search_results or json_path_match not in search_results:
+                            results = []
+
+                    if results:
+                        extracted_data.append(
+                            {
+                                "trace_id": span.get("trace_id"),
+                                "span_id": span.get("span_id"),
+                                "extracted_values": results,
+                            }
+                        )
+                except Exception as e:
+                    # Skip spans that cause errors
+                    logger.debug(
+                        f"Error extracting from span {span.get('trace_id')}: {e}"
+                    )
+                    continue
+            extracted_data = extracted_data[:limit]
+            if value:
+                extracted_data = [
+                    item["extracted_values"]
+                    for item in extracted_data
+                    if item["extracted_values"]
+                ]
+
+            if output == "json" and extracted_data:
+                print(json.dumps(extracted_data, ensure_ascii=False, indent=2))
+            elif extracted_data:
+                for item in extracted_data:
+                    if not value:
+                        for it in item["extracted_values"]:
+                            show_value = json.dumps(it, ensure_ascii=False, indent=2)
+                            print(
+                                f"Trace ID: {item['trace_id']}, Span ID: "
+                                f"{item['span_id']}, \nValue: {show_value}"
+                            )
+                            print("=" * 80)
+                    else:
+                        for it in item:
+                            if isinstance(it, dict):
+                                print(json.dumps(it, ensure_ascii=False, indent=2))
+                            else:
+                                print(it)
+                            print("=" * 80)
+            return
+        except Exception as e:
+            print(f"Error while processing JSONPath: {e}")
+            return
+
+    spans = spans[:limit]
     table = PrettyTable(
         ["Trace ID", "Span ID", "Operation Name", "Conversation UID"],
     )
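An invocation sketch of the new options, using click's test runner so it stays in Python; the JSONPath for -sj, the match value, and the span file path are assumptions, not from the diff:

    from click.testing import CliRunner

    runner = CliRunner()
    result = runner.invoke(
        list,  # the click command defined above
        [
            "-j", "$.metadata.messages[0].content",
            "-sj", "$.metadata.model",
            "-jm", "my-model",
            "--value",
            "traces/spans.jsonl",  # hypothetical span file
        ],
    )
    print(result.output)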
@@ -423,6 +541,8 @@ def read_spans_from_files(files=None) -> Iterable[Dict]:
         for filename in glob.glob(filepath):
             with open(filename, "r") as file:
                 for line in file:
+                    if not line.strip():
+                        continue
                     yield json.loads(line)

@@ -204,13 +204,39 @@ class TracerManager:
 
 root_tracer: TracerManager = TracerManager()
 
 
-def trace(operation_name: Optional[str] = None, **trace_kwargs):
+def trace(
+    operation_name: Optional[str] = None, exclude_params: list = None, **trace_kwargs
+):
+    """Decorator for tracing function calls.
+
+    Args:
+        operation_name: Optional name of the operation. If not provided, it will be
+            derived from the function name.
+        exclude_params: List of parameter names to exclude from metadata extraction.
+        **trace_kwargs: Additional keyword arguments for the tracer.
+
+    Returns:
+        Decorated function with tracing functionality.
+    """
+    if exclude_params is None:
+        exclude_params = []
+
+    # Always exclude 'self' and 'cls' by default
+    default_exclude = ["self", "cls"]
+    exclude_params = default_exclude + exclude_params
+
     def decorator(func):
         @wraps(func)
         def sync_wrapper(*args, **kwargs):
             name = (
                 operation_name if operation_name else _parse_operation_name(func, *args)
             )
 
+            # Extract function parameters as metadata if not provided in trace_kwargs
+            if "metadata" not in trace_kwargs:
+                metadata = _extract_function_params(func, args, kwargs, exclude_params)
+                trace_kwargs["metadata"] = metadata
+
             with root_tracer.start_span(name, **trace_kwargs):
                 return func(*args, **kwargs)
@@ -219,6 +245,12 @@ def trace(operation_name: Optional[str] = None, **trace_kwargs):
             name = (
                 operation_name if operation_name else _parse_operation_name(func, *args)
             )
 
+            # Extract function parameters as metadata if not provided in trace_kwargs
+            if "metadata" not in trace_kwargs:
+                metadata = _extract_function_params(func, args, kwargs, exclude_params)
+                trace_kwargs["metadata"] = metadata
+
             with root_tracer.start_span(name, **trace_kwargs):
                 return await func(*args, **kwargs)
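A usage sketch of the extended decorator (the function and parameter names are hypothetical): call arguments become span metadata unless excluded, and `self`/`cls` are always dropped:

    @trace("api.query_db", exclude_params=["password"])
    async def query_db(self, sql: str, password: str):
        ...  # span metadata contains {"sql": ...} but not "password" or "self"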
@@ -230,6 +262,58 @@ def trace(operation_name: Optional[str] = None, **trace_kwargs):
     return decorator
 
 
+def _extract_function_params(func, args, kwargs, exclude_params):
+    """Extract function parameters as metadata.
+
+    Args:
+        func: The function being traced.
+        args: Positional arguments passed to the function.
+        kwargs: Keyword arguments passed to the function.
+        exclude_params: List of parameter names to exclude.
+
+    Returns:
+        Dict containing parameter names and their values.
+    """
+    metadata = {}
+
+    # Get function signature
+    sig = inspect.signature(func)
+    parameters = list(sig.parameters.items())
+
+    # Process positional arguments
+    for i, arg in enumerate(args):
+        if i < len(parameters):
+            param_name = parameters[i][0]
+            if param_name not in exclude_params:
+                try:
+                    # Try to make the value JSON serializable by converting to str if
+                    # needed
+                    metadata[param_name] = (
+                        str(arg)
+                        if not isinstance(arg, (str, int, float, bool, type(None)))
+                        else arg
+                    )
+                except Exception:
+                    metadata[param_name] = f"<non-serializable: {type(arg).__name__}>"
+
+    # Process keyword arguments
+    for param_name, param_value in kwargs.items():
+        if param_name not in exclude_params:
+            try:
+                # Try to make the value JSON serializable by converting to str if needed
+                metadata[param_name] = (
+                    str(param_value)
+                    if not isinstance(param_value, (str, int, float, bool, type(None)))
+                    else param_value
+                )
+            except Exception:
+                metadata[param_name] = (
+                    f"<non-serializable: {type(param_value).__name__}>"
+                )
+
+    return metadata
+
+
 def _parse_operation_name(func, *args):
     self_name = None
     if inspect.signature(func).parameters.get("self"):
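For illustration, the metadata this helper would produce for a hypothetical call:

    def connect(host, port, timeout=None):
        pass

    meta = _extract_function_params(
        connect, ("db.local", 5432), {"timeout": 3.5}, ["self", "cls"]
    )
    # -> {"host": "db.local", "port": 5432, "timeout": 3.5}; values that are not
    #    str/int/float/bool/None are stringified with str(...).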