fix(model): Fix reasoning output bug (#2393)

Fix reasoning output bug
yyhhyy 2025-03-04 20:42:57 +08:00 committed by GitHub
commit 948a93be32
28 changed files with 1334 additions and 2652 deletions

View File

@ -34,4 +34,4 @@ provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
path = "models/BAAI/glm-4-9b-chat-hf"
path = "models/BAAI/bge-large-zh-v1.5"

View File

@ -285,8 +285,23 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
```
</TabItem>
<TabItem value="llama_cpp" label="LLAMA_CPP(local)">
<TabItem value="llama_cpp" label="LLAMA_CPP(local)">
If you have an NVIDIA GPU, you can enable CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.
```bash
# Use uv to install dependencies needed for llama-cpp
# Install core dependencies and select desired extensions
CMAKE_ARGS="-DGGML_CUDA=ON" uv sync --all-packages \
--extra "base" \
--extra "llama_cpp" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```
Otherwise, run the following command to install dependencies without CUDA support.
```bash
# Use uv to install dependencies needed for llama-cpp
# Install core dependencies and select desired extensions
# Same extras as the CUDA variant above, without CMAKE_ARGS
uv sync --all-packages \
--extra "base" \
--extra "llama_cpp" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```

View File

@ -0,0 +1,4 @@
# DB-GPT Accelerator Module
Building across multiple platforms and hardware is complex; the DB-GPT Accelerator aims to provide compatibility handling for this, offering as consistent an interface as possible to the other core modules.

View File

@ -1,5 +1,5 @@
[project]
name = "dbgpt-accelerator"
name = "dbgpt-acc-auto"
version = "0.7.0"
description = "Add your description here"
authors = [
@ -16,21 +16,6 @@ Documentation = "http://docs.dbgpt.cn/docs/overview"
Repository = "https://github.com/eosphoros-ai/DB-GPT.git"
Issues = "https://github.com/eosphoros-ai/DB-GPT/issues"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/dbgpt_accelerator"]
exclude = [
"src/dbgpt_accelerator/**/tests",
"src/dbgpt_accelerator/**/tests/*",
"src/dbgpt_accelerator/tests",
"src/dbgpt_accelerator/tests/*",
"src/dbgpt_accelerator/**/examples",
"src/dbgpt_accelerator/**/examples/*"
]
[project.optional-dependencies]
# Auto install dependencies
auto = [
@ -76,10 +61,10 @@ vllm = [
# Just support GPU version on Linux
"vllm>=0.7.0; sys_platform == 'linux'",
]
#vllm_pascal = [
# vllm_pascal = [
# # https://github.com/sasha0552/pascal-pkgs-ci
# "vllm-pascal==0.7.2; sys_platform == 'linux'"
#]
# ]
quant_bnb = [
"bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
"accelerate"
@ -103,6 +88,10 @@ quant_gptq = [
"optimum",
"auto-gptq",
]
flash_attn = [
# "torch>=2.2.1",
"dbgpt-acc-flash-attn"
]
[dependency-groups]
auto = [

View File

@ -0,0 +1,3 @@
# DB-GPT-Accelerator for Flash Attention
Wrapper for the Flash Attention module in the DB-GPT-Accelerator.

View File

@ -0,0 +1,24 @@
# Install the flash-attn package for uv
# https://github.com/astral-sh/uv/issues/2252#issuecomment-2624150395
[project]
name = "dbgpt-acc-flash-attn"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = []
[dependency-groups]
build = [
"setuptools>=75.8.0",
]
direct = [
"torch>=2.2.1",
]
main = [
"flash-attn>=2.5.8",
]
[tool.uv]
default-groups = ["build", "direct", "main"]
no-build-isolation-package = ["flash-attn"]

View File

@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">= 3.10"
dependencies = [
"dbgpt-accelerator",
"dbgpt-acc-auto",
"dbgpt",
"dbgpt-ext",
"dbgpt-serve",

View File

@ -102,6 +102,15 @@ class LLMDeployModelParameters(BaseDeployModelParameters, RegisterParameters):
)
},
)
reasoning_model: Optional[bool] = field(
default=None,
metadata={
"help": _(
"Whether the model is a reasoning model. If None, it is "
"automatically determined from model."
)
},
)
@property
def real_provider_model_name(self) -> str:
@ -202,8 +211,10 @@ class BitsandbytesQuantization(BaseHFQuantization):
real_cls = cls
if load_in_8bits:
real_cls = BitsandbytesQuantization8bits
data["type"] = BitsandbytesQuantization8bits.__type__
if load_in_4bits:
real_cls = BitsandbytesQuantization4bits
data["type"] = BitsandbytesQuantization4bits.__type__
real_data = prepare_data_func(real_cls, data)
return real_cls(**real_data)

View File

@ -251,6 +251,27 @@ class LLMModelAdapter(ABC):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def is_reasoning_model(
self,
deploy_model_params: LLMDeployModelParameters,
lower_model_name_or_path: Optional[str] = None,
) -> bool:
"""Whether the model is a reasoning model"""
if (
deploy_model_params.reasoning_model is not None
and deploy_model_params.reasoning_model
):
return True
return (
lower_model_name_or_path
and "deepseek" in lower_model_name_or_path
and (
"r1" in lower_model_name_or_path
or "reasoning" in lower_model_name_or_path
or "reasoner" in lower_model_name_or_path
)
)
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
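
For reference, the detection heuristic added above can be restated as a standalone function. This is an illustrative sketch, not the project's adapter API; the function name and example model names are placeholders.

```python
from typing import Optional


def looks_like_reasoning_model(
    explicit_flag: Optional[bool], lower_model_name_or_path: Optional[str]
) -> bool:
    """Mirror the heuristic above: an explicit flag wins; otherwise
    DeepSeek R1/reasoner-style names are treated as reasoning models."""
    if explicit_flag:
        return True
    return bool(
        lower_model_name_or_path
        and "deepseek" in lower_model_name_or_path
        and (
            "r1" in lower_model_name_or_path
            or "reasoning" in lower_model_name_or_path
            or "reasoner" in lower_model_name_or_path
        )
    )


assert looks_like_reasoning_model(None, "deepseek-r1-distill-qwen-7b")
assert not looks_like_reasoning_model(None, "glm-4-9b-chat")
assert looks_like_reasoning_model(True, "glm-4-9b-chat")
```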

View File

@ -88,6 +88,15 @@ class HFLLMDeployModelParameters(LLMDeployModelParameters):
"valid_values": ["auto", "float16", "bfloat16", "float", "float32"],
},
)
attn_implementation: Optional[str] = field(
default=None,
metadata={
"help": _(
"The attention implementation, only valid in multi-GPU configuration"
),
"valid_values": ["flash_attention_2"],
},
)
@property
def real_model_path(self) -> Optional[str]:

View File

@ -301,15 +301,15 @@ class LlamaServerParameters(LLMDeployModelParameters):
config_dict[fd.name] = curr_config[fd.name]
if (
"device" in config_dict
and config_dict["device"] == "cuda"
self.real_device
and self.real_device == "cuda"
and ("n_gpu_layers" not in config_dict or not config_dict["n_gpu_layers"])
):
# Set n_gpu_layers to a large number to use all layers
logger.info("Set n_gpu_layers to a large number to use all layers")
config_dict["n_gpu_layers"] = 1000000000
config_dict["model_alias"] = self.name
config_dict["model_file"] = self.path
config_dict["model_file"] = self._resolve_root_path(self.path)
model_file = config_dict.get("model_file")
model_url = config_dict.get("model_url")
model_hf_repo = config_dict.get("model_hf_repo")

View File

@ -143,6 +143,8 @@ def huggingface_loader(
if "device_map" in kwargs and "low_cpu_mem_usage" not in kwargs:
# Must set low_cpu_mem_usage to True when device_map is set
kwargs["low_cpu_mem_usage"] = True
if model_params.attn_implementation:
kwargs["attn_implementation"] = model_params.attn_implementation
model, tokenizer = _hf_try_load_default_quantization_model(
model_path, llm_adapter, device, num_gpus, model_params, kwargs
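
For context, a minimal sketch of how such a kwarg typically reaches Hugging Face Transformers. The model id is a placeholder and this is not the project's loader; `"flash_attention_2"` additionally requires the flash-attn package (see the dbgpt-acc-flash-attn package added in this commit).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"  # placeholder model id
kwargs = {"device_map": "auto", "low_cpu_mem_usage": True}

# Forward the configured attention implementation to from_pretrained.
# "flash_attention_2" needs the flash-attn package to be installed.
attn_implementation = "flash_attention_2"
if attn_implementation:
    kwargs["attn_implementation"] = attn_implementation

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
```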

View File

@ -62,6 +62,9 @@ class VLLMDeployModelParameters(LLMDeployModelParameters):
model = data.get("path", None)
if not model:
model = data.get("name", None)
else:
# Path is specified, so we use it as the model
model = self._resolve_root_path(model)
if not model:
raise ValueError(
"Model is required, please specify the model path or name."

View File

@ -421,6 +421,10 @@ class DefaultModelWorker(ModelWorker):
span_params["messages"] = list(
map(lambda m: m.dict(), span_params["messages"])
)
if self.llm_adapter.is_reasoning_model(
self._model_params, self.model_name.lower()
):
params["is_reasoning_model"] = True
metadata = {
"is_async_func": self.support_async(),

View File

@ -613,39 +613,41 @@ class LocalWorkerManager(WorkerManager):
)
async def _start_all_worker(
self, apply_req: WorkerApplyRequest
self, apply_req: WorkerApplyRequest, parallel_num: int = 1
) -> WorkerApplyOutput:
from httpx import TimeoutException, TransportError
# TODO avoid start twice
start_time = time.time()
logger.info(f"Begin start all worker, apply_req: {apply_req}")
semaphore = asyncio.Semaphore(parallel_num)
async def _start_worker(worker_run_data: WorkerRunData):
_start_time = time.time()
info = worker_run_data._to_print_key()
out = WorkerApplyOutput("")
try:
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.command_args,
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
# Register worker to controller
await self.register_func(worker_run_data)
if (
worker_run_data.worker_params.send_heartbeat
and self.send_heartbeat_func
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data,
worker_run_data.worker_params.heartbeat_interval,
self.send_heartbeat_func,
async with semaphore:
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.command_args,
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
# Register worker to controller
await self.register_func(worker_run_data)
if (
worker_run_data.worker_params.send_heartbeat
and self.send_heartbeat_func
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data,
worker_run_data.worker_params.heartbeat_interval,
self.send_heartbeat_func,
)
)
)
out.message = f"{info} start successfully"
out.message = f"{info} start successfully"
except TimeoutException:
out.success = False
out.message = (
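
The change above limits how many workers start concurrently with an `asyncio.Semaphore`. A self-contained sketch of the same pattern, with a toy coroutine standing in for `worker.start`:

```python
import asyncio
import random


async def start_all(worker_names: list[str], parallel_num: int = 1) -> None:
    semaphore = asyncio.Semaphore(parallel_num)

    async def start_one(name: str) -> None:
        # At most `parallel_num` workers are inside this block at any time.
        async with semaphore:
            print(f"starting {name}")
            await asyncio.sleep(random.uniform(0.1, 0.3))  # stand-in for worker.start()
            print(f"{name} started")

    # Tasks are created for all workers; the semaphore throttles actual starts.
    await asyncio.gather(*(start_one(n) for n in worker_names))


asyncio.run(start_all(["worker-1", "worker-2", "worker-3"], parallel_num=2))
```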

View File

@ -13,7 +13,11 @@ from dbgpt.core import ModelOutput
from dbgpt.model.adapter.llama_cpp_py_adapter import LlamaCppModelParameters
from dbgpt.model.utils.llm_utils import parse_model_request
from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)
logger = logging.getLogger(__name__)
@ -113,6 +117,8 @@ class LlamaCppModel:
messages = request.to_common_messages()
repetition_penalty = float(params.get("repetition_penalty", 1.1))
top_k = int(params.get("top_k", -1)) # -1 means disable
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
# Handle truncation
completion_chunks = self.model.create_chat_completion(
messages=messages,
@ -129,6 +135,7 @@ class LlamaCppModel:
usage = None
msg = ParsedChatMessage()
finish_reason: Optional[str] = None
is_first = True
for r in completion_chunks:
if not r.get("choices"):
continue
@ -136,11 +143,16 @@ class LlamaCppModel:
if delta.get("content") is not None:
content = delta["content"]
text += content
msg, _ = parse_chat_message(
content,
extract_reasoning=True,
is_streaming=True,
streaming_state=msg,
if (
is_reasoning_model
and not text.startswith(think_start_token)
and is_first
):
text = think_start_token + "\n" + text
is_first = False
msg = parse_chat_message(
text,
extract_reasoning=is_reasoning_model,
)
finish_reason = delta.get("finish_reason")
if text:
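
The branch above handles reasoning models that omit the opening `<think>` token (for example when the chat template already appends it to the prompt): the adapter prepends the token once to the accumulated text and re-parses the whole message on each chunk. A self-contained sketch of that pattern, with a toy parser standing in for `parse_chat_message`:

```python
THINK_START, THINK_END = "<think>", "</think>"


def split_reasoning(text: str) -> tuple[str, str]:
    """Toy stand-in for the real parser: return (reasoning, content)."""
    if THINK_START in text and THINK_END in text:
        start = text.find(THINK_START) + len(THINK_START)
        end = text.find(THINK_END)
        return text[start:end].strip(), text[end + len(THINK_END):].strip()
    if THINK_START in text:  # still inside the reasoning block
        return text[text.find(THINK_START) + len(THINK_START):].strip(), ""
    return "", text


def stream_parse(chunks, is_reasoning_model=True):
    text, is_first = "", True
    for chunk in chunks:
        text += chunk
        if is_reasoning_model and is_first and not text.startswith(THINK_START):
            # The model omitted the opening token; add it once so parsing works.
            text = THINK_START + "\n" + text
        is_first = False
        yield split_reasoning(text)


# The model starts emitting reasoning directly and closes with </think>.
for reasoning, content in stream_parse(
    ["Let me think", " about it.</think>", "The answer is 42."]
):
    print(repr(reasoning), "|", repr(content))
```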

View File

@ -17,7 +17,11 @@ from typing import Dict, Optional
from dbgpt.core import ModelOutput
from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)
logger = logging.getLogger(__name__)
@ -78,7 +82,10 @@ def chat_generate_stream(
):
req = _build_chat_completion_request(params, stream=True)
text = ""
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
msg = ParsedChatMessage()
is_first = True
for r in model.stream_chat_completion(req):
if len(r.choices) == 0:
continue
@ -86,13 +93,17 @@ def chat_generate_stream(
if r.choices[0] is not None and r.choices[0].delta is None:
continue
content = r.choices[0].delta.content
if content is None:
continue
text += content
if is_reasoning_model and not text.startswith(think_start_token) and is_first:
text = think_start_token + "\n" + text
is_first = False
msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
finish_reason = _parse_finish_reason(r.choices[0].finish_reason)
if content is not None:
text += content
msg, _ = parse_chat_message(
content, extract_reasoning=True, is_streaming=True, streaming_state=msg
)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,

View File

@ -6,7 +6,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
from dbgpt.core import ModelOutput
from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)
logger = logging.getLogger(__name__)
@ -27,6 +31,8 @@ def huggingface_chat_generate_stream(
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", True)
custom_stop_words = params.get("custom_stop_words", [])
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
@ -65,15 +71,23 @@ def huggingface_chat_generate_stream(
text = ""
usage = None
msg = ParsedChatMessage()
is_first = True
for new_text in streamer:
text += new_text
msg, _ = parse_chat_message(
new_text, extract_reasoning=True, is_streaming=True, streaming_state=msg
)
if custom_stop_words:
for stop_word in custom_stop_words:
if text.endswith(stop_word):
text = text[: -len(stop_word)]
if (
prompt.rstrip().endswith(think_start_token)
and is_reasoning_model
and is_first
):
text = think_start_token + "\n" + text
is_first = False
msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,

View File

@ -31,6 +31,7 @@ async def generate_stream(
best_of = params.get("best_of", None)
stop_str = params.get("stop", None)
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
# think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
stop_token_ids = params.get("stop_token_ids", None) or []
@ -104,11 +105,11 @@ async def generate_stream(
)
if text_outputs:
# Tempora
if prompt.rstrip().endswith(think_start_token):
if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:
text_outputs = think_start_token + "\n" + text_outputs
msg = parse_chat_message(
text_outputs,
extract_reasoning=True,
extract_reasoning=is_reasoning_model,
)
yield ModelOutput.build(
msg.content,

View File

@ -9,13 +9,13 @@ from dbgpt.core.awel.flow import (
ResourceCategory,
auto_register_resource,
)
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.proxy.base import (
AsyncGenerateStreamFunction,
GenerateStreamFunction,
ProxyLLMClient,
register_proxy_model_adapter,
)
from dbgpt.model.proxy.llms.chatgpt import OpenAICompatibleDeployModelParameters
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.util.i18n_utils import _
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
show_in_ui=False,
)
@dataclass
class OllamaDeployModelParameters(OpenAICompatibleDeployModelParameters):
class OllamaDeployModelParameters(LLMDeployModelParameters):
"""Deploy model parameters for Ollama."""
provider: str = "proxy/ollama"

View File

@ -128,13 +128,13 @@ def process_streaming_chunk(
if end_marker in remaining_chunk:
end_idx = remaining_chunk.find(end_marker)
# Output reasoning content event
if end_idx > 0:
reasoning_part = remaining_chunk[:end_idx]
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)
# Append reasoning content instead of replacing
msg.reasoning_content += reasoning_part
# if end_idx > 0:
reasoning_part = remaining_chunk[:end_idx]
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)
# Append reasoning content instead of replacing
msg.reasoning_content += reasoning_part
# Output reasoning end event
events.append(StreamingEvent(type="reasoning_end", content=""))
@ -220,6 +220,49 @@ def process_streaming_chunk(
remaining_chunk = ""
continue
# Check for reasoning end markers without matching start markers
# This is the special case to handle
found_end_marker = False
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]
if end_marker in remaining_chunk and not state["in_reasoning"]:
end_idx = remaining_chunk.find(end_marker)
start_idx = 0
if start_marker in remaining_chunk:
start_idx = remaining_chunk.find(start_marker) + len(start_marker)
# This is content that should be treated as reasoning but didn't have a
# start tag
# if end_idx > 0:
reasoning_part = remaining_chunk[start_idx:end_idx]
# Clear regular content
reasoning_part = msg.content + reasoning_part
msg.content = ""
# First, emit a reasoning_start event
events.append(StreamingEvent(type="reasoning_start", content=""))
# Then emit the content as reasoning content
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)
# Add to reasoning content
msg.reasoning_content += reasoning_part
# Emit the reasoning_end event
events.append(StreamingEvent(type="reasoning_end", content=""))
# Move past the end marker
remaining_chunk = remaining_chunk[end_idx + len(end_marker) :]
found_end_marker = True
state["reasoning_pattern"] = None
break
# If we found an end marker, continue to the next iteration
if found_end_marker:
continue
# Check for reasoning start markers
reasoning_start_found = False
for pattern in reasoning_patterns:
@ -228,10 +271,10 @@ def process_streaming_chunk(
start_idx = remaining_chunk.find(start_marker)
# Output regular content before the marker
if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part
# if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part
# Output reasoning start event
events.append(StreamingEvent(type="reasoning_start", content=""))
@ -257,12 +300,10 @@ def process_streaming_chunk(
start_idx = remaining_chunk.find(start_marker)
# Output regular content before the marker
if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(
StreamingEvent(type="content", content=content_part)
)
msg.content += content_part
# if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part
# Output tool call start event
events.append(StreamingEvent(type="tool_call_start", content=""))
@ -355,6 +396,7 @@ def parse_chat_message(
reasoning_content = ""
content = input_text
# First check for the normal case with proper start and end markers
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]
@ -372,6 +414,44 @@ def parse_chat_message(
content = content[:start_idx] + content[end_idx + len(end_marker) :]
break
# If no reasoning content was found with the standard pattern, check for the
# special case
# where content starts with reasoning but has no start marker
if not reasoning_content:
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]
if end_marker in content:
# Check if this is at the beginning of the content or
# if there's no matching start marker before it
end_idx = content.find(end_marker)
start_marker = pattern["start"]
start_idx = content.find(start_marker)
# If no start marker or end marker appears before start marker
if start_idx == -1 or end_idx < start_idx:
# This is our special case - treat the content up to the end marker
# as reasoning
reasoning_content = string_strip(content[:end_idx])
# Remove reasoning part from original content
if extract_reasoning:
content = content[end_idx + len(end_marker) :]
break
elif start_marker in content:
# If there's a start marker but no end marker, treat the content
# as reasoning content
start_idx = content.find(start_marker)
reasoning_content = string_strip(
content[start_idx + len(start_marker) :]
)
# Remove reasoning part from original content
if extract_reasoning:
content = ""
break
# Parse tool calls
tool_calls_text = ""

View File

@ -1,5 +1,3 @@
import pytest
from ..parse_utils import parse_chat_message
@ -357,8 +355,8 @@ def test_streaming_mode_without_tool_calls():
# Verify final message should include tool call markers
assert (
"I will search for data.<tool▁calls▁begin>Tool call content<tool▁calls▁end>Search complete." # noqa
== msg.content
"I will search for data.<tool▁calls▁begin>Tool call content"
"<tool▁calls▁end>Search complete." == msg.content # noqa
)
assert len(msg.tool_calls) == 0
@ -449,48 +447,6 @@ def test_incomplete_markers():
assert "Tool content" in msg2.streaming_state.get("tool_call_text", "")
def test_multiple_special_sections():
"""Test handling multiple special sections"""
input_text = """<think>Reasoning content 1</think>Regular content 1
<tool▁calls▁begin>Tool call content<tool▁calls▁end>
Regular content 2<think>Reasoning content 2</think>End"""
result = parse_chat_message(input_text, extract_tool_calls=True)
# Verify only first reasoning content is extracted
assert "Reasoning content 1" == result.reasoning_content
assert (
"Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
== result.content
)
# Use streaming processing to handle multiple reasoning parts
chunks = [
"<think>Reasoning content 1</think>Regular content 1\n",
"<tool▁calls▁begin>Tool call content<tool▁calls▁end>\n",
"Regular content 2<think>Reasoning content 2</think>End",
]
msg = None
all_events = []
for chunk in chunks:
msg, events = parse_chat_message(
chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
)
all_events.extend(events)
# Verify event sequence contains two reasoning sections
reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
assert reasoning_start_counts == 2
assert reasoning_end_counts == 2
# In streaming mode, reasoning content should accumulate
assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content
def test_custom_streaming_patterns():
"""Test custom streaming pattern markers"""
custom_reasoning = [{"start": "{{thinking}}", "end": "{{/thinking}}"}]
@ -530,41 +486,239 @@ def test_custom_streaming_patterns():
assert "tool_call_end" in event_types
def test_alternating_reasoning_and_tool_calls():
"""Test alternating between reasoning and tool calls in a single message"""
# Use streaming to capture all sections
def test_missing_start_token_non_streaming():
"""Test parsing messages with missing start token but having end token
(non-streaming mode)
"""
input_text = """Model reasoning content without start token.
</think>
This is the regular content part."""
result = parse_chat_message(input_text, extract_reasoning=True)
assert "This is the regular content part." == result.content
assert "Model reasoning content without start token." == result.reasoning_content
def test_missing_start_token_streaming():
"""Test parsing messages with missing start token but having end token
(streaming mode)
"""
chunks = [
"<think>First reasoning block</think>Content 1\n",
"<tool▁calls▁begin>Tool call 1<tool▁calls▁end>\n",
"Content 2<think>Second reasoning block</think>\n",
"<tool▁calls▁begin>Tool call 2<tool▁calls▁end>\n",
"Final content",
"Model reasoning content ",
"without start token.</think>",
"This is the regular content part.",
]
msg = None
all_events = []
for chunk in chunks:
msg, events = parse_chat_message(
chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
)
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)
# Verify content is parsed correctly - note the double newlines
assert "Content 1\n\nContent 2\n\nFinal content" == msg.content
assert "First reasoning blockSecond reasoning block" == msg.reasoning_content
# Verify final message - match the expected result in test
assert "This is the regular content part." == msg.content
assert "Model reasoning content without start token." == msg.reasoning_content
# Count events by type
event_counts = {}
for e in all_events:
event_counts[e.type] = event_counts.get(e.type, 0) + 1
assert event_counts.get("reasoning_start", 0) == 2
assert event_counts.get("reasoning_end", 0) == 2
assert event_counts.get("tool_call_start", 0) == 2
assert event_counts.get("tool_call_end", 0) == 2
# Verify event sequence contains correct reasoning events
event_types = [e.type for e in all_events]
assert "reasoning_start" in event_types
assert "reasoning_content" in event_types
assert "reasoning_end" in event_types
if __name__ == "__main__":
# Run tests
pytest.main(["-v", "test_parse_utils.py"])
def test_missing_start_token_deepseek_chinese():
"""Test the DeepSeek example with Chinese content missing start token"""
input_text = """您好我是由中国的深度求索DeepSeek公司开发的智能助手\
DeepSeek-R1有关模型和产品的详细内容请参考官方文档
</think>
您好我是由中国的深度求索DeepSeek公司开发的智能助手DeepSeek-R1有关模型\
和产品的详细内容请参考官方文档"""
result = parse_chat_message(input_text, extract_reasoning=True)
assert (
"您好我是由中国的深度求索DeepSeek公司开发的智能助手DeepSeek-R1。"
"有关模型和产品的详细内容请参考官方文档。" == result.content
)
assert (
"您好我是由中国的深度求索DeepSeek公司开发的智能助手DeepSeek-R1。"
"有关模型和产品的详细内容请参考官方文档。" == result.reasoning_content
)
def test_multiple_missing_start_tokens():
"""Test multiple occurrences of missing start tokens in the same message"""
input_text = """First reasoning section.
</think>
Some regular content.
Second reasoning section.
</reasoning>
More regular content."""
result = parse_chat_message(input_text, extract_reasoning=True)
# Note: In non-streaming mode, only the first matching reasoning content is
# extracted
assert (
"Some regular content.\nSecond reasoning section.\n</reasoning>\nMore regular "
"content." == result.content
)
assert "First reasoning section." == result.reasoning_content
# Use streaming to capture all sections
chunks = [
"First reasoning section.\n</think>\n",
"Some regular content.\n",
"Second reasoning section.\n</reasoning>\n",
"More regular content.",
]
msg = None
all_events = []
for chunk in chunks:
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)
# In streaming mode, reasoning content should match the expected format
assert (
"First reasoning section.\n\nSome regular content.\nSecond reasoning section.\n"
== msg.reasoning_content
)
# Verify event sequence contains two reasoning sections
reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
assert reasoning_start_counts == 2
assert reasoning_end_counts == 2
def test_missing_start_token_with_tools():
"""Test missing reasoning start token with tool calls"""
input_text = """Analyzing user request to query weather information.
</think>
I'll look up the weather data for you.
<tool▁calls▁begin>
<tool▁call▁begin>function<tool▁sep>get_weather
```json
{
"location": "Beijing",
"date": "2023-05-20"
}
```
<tool▁call▁end>
<tool▁calls▁end>"""
result = parse_chat_message(
input_text, extract_reasoning=True, extract_tool_calls=True
)
assert "I'll look up the weather data for you." in result.content
assert (
"Analyzing user request to query weather information."
== result.reasoning_content
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0]["name"] == "get_weather"
assert result.tool_calls[0]["arguments"]["location"] == "Beijing"
def test_mixed_language_missing_start_token():
"""Test mixed Chinese and English content with missing start token"""
input_text = """这是一段中英文混合的思考内容 with both languages mixed together.
</think>
Here's the regular content with 中文 mixed in."""
result = parse_chat_message(input_text, extract_reasoning=True)
assert "Here's the regular content with 中文 mixed in." == result.content
assert (
"这是一段中英文混合的思考内容 with both languages mixed together."
== result.reasoning_content
)
def test_streaming_mixed_language_missing_start():
"""Test streaming mixed language content with missing start token"""
chunks = [
"Analysis 分析: The user needs ",
"information about 关于天气的信息。</reasoning>",
"I'll provide weather information 我将提供天气信息。",
]
msg = None
all_events = []
for chunk in chunks:
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)
# Verify final message includes mixed language content properly parsed
assert "I'll provide weather information 我将提供天气信息。" == msg.content
assert (
"Analysis 分析: The user needs information about 关于天气的信息。"
== msg.reasoning_content
)
# Verify events sequence
reasoning_events = [e for e in all_events if e.type.startswith("reasoning_")]
assert len(reasoning_events) >= 3 # At least start, content, and end events
def test_chinese_pattern_missing_start():
"""Test Chinese pattern with missing start token"""
input_text = """这里是模型的思考内容,但是没有开始标记。
</思考>
这是正常的响应内容。"""
result = parse_chat_message(input_text, extract_reasoning=True)
assert "这是正常的响应内容。" == result.content
assert "这里是模型的思考内容,但是没有开始标记。" == result.reasoning_content
#
# def test_multiple_special_sections():
# """Test handling multiple special sections"""
# input_text = """<think>Reasoning content 1</think>Regular content 1
# <tool▁calls▁begin>Tool call content<tool▁calls▁end>
# Regular content 2<think>Reasoning content 2</think>End"""
#
# result = parse_chat_message(input_text, extract_tool_calls=True)
#
# # Verify only first reasoning content is extracted
# assert "Reasoning content 1" == result.reasoning_content
# assert (
# "Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
# == result.content
# )
#
# # Use streaming processing to handle multiple reasoning parts
# chunks = [
# "<think>Reasoning content 1</think>Regular content 1\n",
# "<tool▁calls▁begin>Tool call content<tool▁calls▁end>\n",
# "Regular content 2<think>Reasoning content 2</think>End",
# ]
#
# msg = None
# all_events = []
#
# for chunk in chunks:
# msg, events = parse_chat_message(
# chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
# )
# all_events.extend(events)
#
# # Verify event sequence contains two reasoning sections
# reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
# reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
#
# assert reasoning_start_counts == 2
# assert reasoning_end_counts == 2
#
# # In streaming mode, reasoning content should match the expected format
# assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content

View File

@ -113,7 +113,7 @@ class HFEmbeddingDeployModelParameters(EmbeddingDeployModelParameters):
@property
def real_provider_model_name(self) -> str:
"""Get the real provider model name."""
return self.path or self.name
return self.real_model_path or self.name
@property
def real_model_path(self) -> Optional[str]:

View File

@ -40,8 +40,7 @@ class ClickhouseParameters(BaseDatasourceParameters):
user: str = field(metadata={"help": _("Database user to connect")})
database: str = field(metadata={"help": _("Database name")})
engine: str = field(
default="MergeTree",
metadata={"help": _("Storage engine, e.g., MergeTree")}
default="MergeTree", metadata={"help": _("Storage engine, e.g., MergeTree")}
)
password: str = field(
default="${env:DBGPT_DB_PASSWORD}",

View File

@ -13,16 +13,22 @@ readme = "README.md"
requires-python = ">= 3.10"
[tool.uv.sources]
dbgpt-accelerator = { workspace = true }
dbgpt = { workspace = true }
dbgpt-client = { workspace = true }
dbgpt-ext = { workspace = true }
dbgpt-serve = { workspace = true }
dbgpt-app = { workspace = true }
dbgpt-acc-auto = { workspace = true }
dbgpt-acc-flash-attn = { workspace = true }
[tool.uv.workspace]
members = [
"packages/dbgpt-*"
"packages/dbgpt-app",
"packages/dbgpt-client",
"packages/dbgpt-core",
"packages/dbgpt-ext",
"packages/dbgpt-serve",
"packages/dbgpt-accelerator/*"
]
[tool.uv]
@ -65,4 +71,4 @@ select = ["E", "F", "I"]
[tool.ruff.isort]
# Specify the local modules (first-party)
known-first-party = ["dbgpt", "dbgpt_accelerator", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]
known-first-party = ["dbgpt", "dbgpt_acc_auto", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]

uv.lock (3304 changed lines)

File diff suppressed because one or more lines are too long