fix(model): Fix reasoning output bug

Repository: https://github.com/csunny/DB-GPT.git
Commit c8e252c4de (parent e3a25de7f7)
@@ -34,4 +34,4 @@ provider = "hf"
 # If not provided, the model will be downloaded from the Hugging Face model hub
 # uncomment the following line to specify the model path in the local file system
 # path = "the-model-path-in-the-local-file-system"
-path = "models/BAAI/glm-4-9b-chat-hf"
+path = "models/BAAI/bge-large-zh-v1.5"
@@ -285,8 +285,23 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
 ```

 </TabItem>
 <TabItem value="llama_cpp" label="LLAMA_CPP(local)">

+If you have an Nvidia GPU, you can enable CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.
+
+```bash
+# Use uv to install dependencies needed for llama-cpp
+# Install core dependencies and select desired extensions
+CMAKE_ARGS="-DGGML_CUDA=ON" uv sync --all-packages \
+--extra "base" \
+--extra "llama_cpp" \
+--extra "rag" \
+--extra "storage_chromadb" \
+--extra "quant_bnb" \
+--extra "dbgpts"
+```
+
+Otherwise, run the following command to install dependencies without CUDA support.
 ```bash
 # Use uv to install dependencies needed for llama-cpp
 # Install core dependencies and select desired extensions
packages/dbgpt-accelerator/dbgpt-acc-auto/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
+# DB-GPT Accelerator Module
+
+Building across multiple platforms and hardware is complex, and the DB-GPT Accelerator aims to provide compatibility handling for this, offering as consistent an interface as possible for other core models.
+
@@ -1,5 +1,5 @@
 [project]
-name = "dbgpt-accelerator"
+name = "dbgpt-acc-auto"
 version = "0.7.0"
 description = "Add your description here"
 authors = [
@@ -16,21 +16,6 @@ Documentation = "http://docs.dbgpt.cn/docs/overview"
 Repository = "https://github.com/eosphoros-ai/DB-GPT.git"
 Issues = "https://github.com/eosphoros-ai/DB-GPT/issues"

-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
-
-[tool.hatch.build.targets.wheel]
-packages = ["src/dbgpt_accelerator"]
-exclude = [
-    "src/dbgpt_accelerator/**/tests",
-    "src/dbgpt_accelerator/**/tests/*",
-    "src/dbgpt_accelerator/tests",
-    "src/dbgpt_accelerator/tests/*",
-    "src/dbgpt_accelerator/**/examples",
-    "src/dbgpt_accelerator/**/examples/*"
-]
-
 [project.optional-dependencies]
 # Auto install dependencies
 auto = [
@@ -76,10 +61,10 @@ vllm = [
     # Just support GPU version on Linux
     "vllm>=0.7.0; sys_platform == 'linux'",
 ]
-#vllm_pascal = [
+# vllm_pascal = [
 # # https://github.com/sasha0552/pascal-pkgs-ci
 # "vllm-pascal==0.7.2; sys_platform == 'linux'"
-#]
+# ]
 quant_bnb = [
     "bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
     "accelerate"
@@ -103,6 +88,10 @@ quant_gptq = [
     "optimum",
     "auto-gptq",
 ]
+flash_attn = [
+    # "torch>=2.2.1",
+    "dbgpt-acc-flash-attn"
+]

 [dependency-groups]
 auto = [
New file:
@@ -0,0 +1,3 @@
+# DB-GPT-Accelerator for Flash Attention
+
+Wrapper for the Flash Attention module in the DB-GPT-Accelerator.
New file:
@@ -0,0 +1,24 @@
+# Install the flash-attn package for uv
+# https://github.com/astral-sh/uv/issues/2252#issuecomment-2624150395
+[project]
+name = "dbgpt-acc-flash-attn"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = []
+
+[dependency-groups]
+build = [
+    "setuptools>=75.8.0",
+]
+direct = [
+    "torch>=2.2.1",
+]
+main = [
+    "flash-attn>=2.5.8",
+]
+
+[tool.uv]
+default-groups = ["build", "direct", "main"]
+no-build-isolation-package = ["flash-attn"]
@@ -10,7 +10,7 @@ readme = "README.md"
 requires-python = ">= 3.10"

 dependencies = [
-    "dbgpt-accelerator",
+    "dbgpt-acc-auto",
     "dbgpt",
    "dbgpt-ext",
    "dbgpt-serve",
@@ -102,6 +102,15 @@ class LLMDeployModelParameters(BaseDeployModelParameters, RegisterParameters):
             )
         },
     )
+    reasoning_model: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": _(
+                "Whether the model is a reasoning model. If None, it is "
+                "automatically determined from model."
+            )
+        },
+    )

     @property
     def real_provider_model_name(self) -> str:
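The `reasoning_model` field added above can also be set explicitly instead of relying on the name-based auto-detection introduced later in this commit. A minimal sketch, assuming the `[[models.llms]]` TOML layout used by the config shown at the top of this diff (the model name here is only an illustrative placeholder):

```toml
# Hypothetical model entry; reasoning_model is the new flag from this commit.
[[models.llms]]
name = "DeepSeek-R1-Distill-Qwen-32B"
provider = "hf"
# true forces reasoning parsing on; leaving it unset (None) lets the adapter
# auto-detect from the model name.
reasoning_model = true
```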
@@ -202,8 +211,10 @@ class BitsandbytesQuantization(BaseHFQuantization):
         real_cls = cls
         if load_in_8bits:
             real_cls = BitsandbytesQuantization8bits
+            data["type"] = BitsandbytesQuantization8bits.__type__
         if load_in_4bits:
             real_cls = BitsandbytesQuantization4bits
+            data["type"] = BitsandbytesQuantization4bits.__type__
         real_data = prepare_data_func(real_cls, data)
         return real_cls(**real_data)
@@ -251,6 +251,27 @@ class LLMModelAdapter(ABC):
         """Load the model and tokenizer according to the given parameters"""
         raise NotImplementedError

+    def is_reasoning_model(
+        self,
+        deploy_model_params: LLMDeployModelParameters,
+        lower_model_name_or_path: Optional[str] = None,
+    ) -> bool:
+        """Whether the model is a reasoning model"""
+        if (
+            deploy_model_params.reasoning_model is not None
+            and deploy_model_params.reasoning_model
+        ):
+            return True
+        return (
+            lower_model_name_or_path
+            and "deepseek" in lower_model_name_or_path
+            and (
+                "r1" in lower_model_name_or_path
+                or "reasoning" in lower_model_name_or_path
+                or "reasoner" in lower_model_name_or_path
+            )
+        )
+
     def support_async(self) -> bool:
         """Whether the loaded model supports asynchronous calls"""
         return False
@@ -88,6 +88,15 @@ class HFLLMDeployModelParameters(LLMDeployModelParameters):
             "valid_values": ["auto", "float16", "bfloat16", "float", "float32"],
         },
     )
+    attn_implementation: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": _(
+                "The attention implementation, only valid in multi-GPU configuration"
+            ),
+            "valid_values": ["flash_attention_2"],
+        },
+    )

     @property
     def real_model_path(self) -> Optional[str]:
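The new `attn_implementation` field can be wired up the same way. A minimal sketch under the same assumed TOML layout; `flash_attention_2` is the only declared valid value and relies on the flash-attn packaging added earlier in this commit:

```toml
# Hypothetical model entry; attn_implementation is the new field from this commit.
[[models.llms]]
name = "glm-4-9b-chat-hf"
provider = "hf"
# Forwarded to the Hugging Face loader as attn_implementation (see huggingface_loader below).
attn_implementation = "flash_attention_2"
```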
@@ -301,15 +301,15 @@ class LlamaServerParameters(LLMDeployModelParameters):
                 config_dict[fd.name] = curr_config[fd.name]

         if (
-            "device" in config_dict
-            and config_dict["device"] == "cuda"
+            self.real_device
+            and self.real_device == "cuda"
             and ("n_gpu_layers" not in config_dict or not config_dict["n_gpu_layers"])
         ):
             # Set n_gpu_layers to a large number to use all layers
             logger.info("Set n_gpu_layers to a large number to use all layers")
             config_dict["n_gpu_layers"] = 1000000000
         config_dict["model_alias"] = self.name
-        config_dict["model_file"] = self.path
+        config_dict["model_file"] = self._resolve_root_path(self.path)
         model_file = config_dict.get("model_file")
         model_url = config_dict.get("model_url")
         model_hf_repo = config_dict.get("model_hf_repo")
@@ -143,6 +143,8 @@ def huggingface_loader(
     if "device_map" in kwargs and "low_cpu_mem_usage" not in kwargs:
         # Must set low_cpu_mem_usage to True when device_map is set
         kwargs["low_cpu_mem_usage"] = True
+    if model_params.attn_implementation:
+        kwargs["attn_implementation"] = model_params.attn_implementation

     model, tokenizer = _hf_try_load_default_quantization_model(
         model_path, llm_adapter, device, num_gpus, model_params, kwargs
@@ -62,6 +62,9 @@ class VLLMDeployModelParameters(LLMDeployModelParameters):
         model = data.get("path", None)
         if not model:
             model = data.get("name", None)
+        else:
+            # Path is specified, so we use it as the model
+            model = self._resolve_root_path(model)
         if not model:
             raise ValueError(
                 "Model is required, please specify the model path or name."
@@ -421,6 +421,10 @@ class DefaultModelWorker(ModelWorker):
             span_params["messages"] = list(
                 map(lambda m: m.dict(), span_params["messages"])
             )
+        if self.llm_adapter.is_reasoning_model(
+            self._model_params, self.model_name.lower()
+        ):
+            params["is_reasoning_model"] = True

         metadata = {
             "is_async_func": self.support_async(),
@@ -613,39 +613,41 @@ class LocalWorkerManager(WorkerManager):
         )

     async def _start_all_worker(
-        self, apply_req: WorkerApplyRequest
+        self, apply_req: WorkerApplyRequest, parallel_num: int = 1
     ) -> WorkerApplyOutput:
         from httpx import TimeoutException, TransportError

         # TODO avoid start twice
         start_time = time.time()
         logger.info(f"Begin start all worker, apply_req: {apply_req}")
+        semaphore = asyncio.Semaphore(parallel_num)

         async def _start_worker(worker_run_data: WorkerRunData):
             _start_time = time.time()
             info = worker_run_data._to_print_key()
             out = WorkerApplyOutput("")
             try:
-                await self.run_blocking_func(
-                    worker_run_data.worker.start,
-                    worker_run_data.command_args,
-                )
-                worker_run_data.stop_event.clear()
-                if worker_run_data.worker_params.register and self.register_func:
-                    # Register worker to controller
-                    await self.register_func(worker_run_data)
-                    if (
-                        worker_run_data.worker_params.send_heartbeat
-                        and self.send_heartbeat_func
-                    ):
-                        asyncio.create_task(
-                            _async_heartbeat_sender(
-                                worker_run_data,
-                                worker_run_data.worker_params.heartbeat_interval,
-                                self.send_heartbeat_func,
-                            )
-                        )
-                out.message = f"{info} start successfully"
+                async with semaphore:
+                    await self.run_blocking_func(
+                        worker_run_data.worker.start,
+                        worker_run_data.command_args,
+                    )
+                    worker_run_data.stop_event.clear()
+                    if worker_run_data.worker_params.register and self.register_func:
+                        # Register worker to controller
+                        await self.register_func(worker_run_data)
+                        if (
+                            worker_run_data.worker_params.send_heartbeat
+                            and self.send_heartbeat_func
+                        ):
+                            asyncio.create_task(
+                                _async_heartbeat_sender(
+                                    worker_run_data,
+                                    worker_run_data.worker_params.heartbeat_interval,
+                                    self.send_heartbeat_func,
+                                )
+                            )
+                    out.message = f"{info} start successfully"
             except TimeoutException:
                 out.success = False
                 out.message = (
@@ -13,7 +13,11 @@ from dbgpt.core import ModelOutput
 from dbgpt.model.adapter.llama_cpp_py_adapter import LlamaCppModelParameters
 from dbgpt.model.utils.llm_utils import parse_model_request

-from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
+from ...utils.parse_utils import (
+    _DEFAULT_THINK_START_TOKEN,
+    ParsedChatMessage,
+    parse_chat_message,
+)

 logger = logging.getLogger(__name__)

@@ -113,6 +117,8 @@ class LlamaCppModel:
         messages = request.to_common_messages()
         repetition_penalty = float(params.get("repetition_penalty", 1.1))
         top_k = int(params.get("top_k", -1))  # -1 means disable
+        think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
+        is_reasoning_model = params.get("is_reasoning_model", False)
         # Handle truncation
         completion_chunks = self.model.create_chat_completion(
             messages=messages,
@@ -129,6 +135,7 @@ class LlamaCppModel:
         usage = None
         msg = ParsedChatMessage()
         finish_reason: Optional[str] = None
+        is_first = True
         for r in completion_chunks:
             if not r.get("choices"):
                 continue
@@ -136,11 +143,16 @@
             if delta.get("content") is not None:
                 content = delta["content"]
                 text += content
-                msg, _ = parse_chat_message(
-                    content,
-                    extract_reasoning=True,
-                    is_streaming=True,
-                    streaming_state=msg,
-                )
+                if (
+                    is_reasoning_model
+                    and not text.startswith(think_start_token)
+                    and is_first
+                ):
+                    text = think_start_token + "\n" + text
+                    is_first = False
+                msg = parse_chat_message(
+                    text,
+                    extract_reasoning=is_reasoning_model,
+                )
             finish_reason = delta.get("finish_reason")
             if text:
@@ -17,7 +17,11 @@ from typing import Dict, Optional

 from dbgpt.core import ModelOutput

-from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
+from ...utils.parse_utils import (
+    _DEFAULT_THINK_START_TOKEN,
+    ParsedChatMessage,
+    parse_chat_message,
+)

 logger = logging.getLogger(__name__)

@@ -78,7 +82,10 @@ def chat_generate_stream(
 ):
     req = _build_chat_completion_request(params, stream=True)
     text = ""
+    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
+    is_reasoning_model = params.get("is_reasoning_model", False)
     msg = ParsedChatMessage()
+    is_first = True
     for r in model.stream_chat_completion(req):
         if len(r.choices) == 0:
             continue
@@ -86,13 +93,17 @@ def chat_generate_stream(
         if r.choices[0] is not None and r.choices[0].delta is None:
             continue
         content = r.choices[0].delta.content
+        if content is None:
+            continue
+
+        text += content
+        if is_reasoning_model and not text.startswith(think_start_token) and is_first:
+            text = think_start_token + "\n" + text
+            is_first = False
+
+        msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
         finish_reason = _parse_finish_reason(r.choices[0].finish_reason)

-        if content is not None:
-            text += content
-            msg, _ = parse_chat_message(
-                content, extract_reasoning=True, is_streaming=True, streaming_state=msg
-            )
         yield ModelOutput.build(
             msg.content,
             msg.reasoning_content,
@@ -6,7 +6,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream

 from dbgpt.core import ModelOutput

-from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
+from ...utils.parse_utils import (
+    _DEFAULT_THINK_START_TOKEN,
+    ParsedChatMessage,
+    parse_chat_message,
+)

 logger = logging.getLogger(__name__)

@@ -27,6 +31,8 @@ def huggingface_chat_generate_stream(
     stop_token_ids = params.get("stop_token_ids", [])
     do_sample = params.get("do_sample", True)
     custom_stop_words = params.get("custom_stop_words", [])
+    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
+    is_reasoning_model = params.get("is_reasoning_model", False)

     input_ids = tokenizer(prompt).input_ids
     # input_ids = input_ids.to(device)
@@ -65,15 +71,23 @@ def huggingface_chat_generate_stream(
     text = ""
     usage = None
     msg = ParsedChatMessage()
+    is_first = True
     for new_text in streamer:
         text += new_text
-        msg, _ = parse_chat_message(
-            new_text, extract_reasoning=True, is_streaming=True, streaming_state=msg
-        )
         if custom_stop_words:
             for stop_word in custom_stop_words:
                 if text.endswith(stop_word):
                     text = text[: -len(stop_word)]
+
+        if (
+            prompt.rstrip().endswith(think_start_token)
+            and is_reasoning_model
+            and is_first
+        ):
+            text = think_start_token + "\n" + text
+            is_first = False
+
+        msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
         yield ModelOutput.build(
             msg.content,
             msg.reasoning_content,
@@ -31,6 +31,7 @@ async def generate_stream(
     best_of = params.get("best_of", None)
     stop_str = params.get("stop", None)
     think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
+    is_reasoning_model = params.get("is_reasoning_model", False)
     # think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)

     stop_token_ids = params.get("stop_token_ids", None) or []
@@ -104,11 +105,11 @@ async def generate_stream(
             )
             if text_outputs:
                 # Tempora
-                if prompt.rstrip().endswith(think_start_token):
+                if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:
                     text_outputs = think_start_token + "\n" + text_outputs
             msg = parse_chat_message(
                 text_outputs,
-                extract_reasoning=True,
+                extract_reasoning=is_reasoning_model,
             )
             yield ModelOutput.build(
                 msg.content,
@@ -128,13 +128,13 @@ def process_streaming_chunk(
             if end_marker in remaining_chunk:
                 end_idx = remaining_chunk.find(end_marker)
                 # Output reasoning content event
-                if end_idx > 0:
+                # if end_idx > 0:
                 reasoning_part = remaining_chunk[:end_idx]
                 events.append(
                     StreamingEvent(type="reasoning_content", content=reasoning_part)
                 )
                 # Append reasoning content instead of replacing
                 msg.reasoning_content += reasoning_part

                 # Output reasoning end event
                 events.append(StreamingEvent(type="reasoning_end", content=""))
@@ -220,6 +220,49 @@ def process_streaming_chunk(
                 remaining_chunk = ""
             continue

+        # Check for reasoning end markers without matching start markers
+        # This is the special case to handle
+        found_end_marker = False
+        for pattern in reasoning_patterns:
+            start_marker = pattern["start"]
+            end_marker = pattern["end"]
+            if end_marker in remaining_chunk and not state["in_reasoning"]:
+                end_idx = remaining_chunk.find(end_marker)
+                start_idx = 0
+                if start_marker in remaining_chunk:
+                    start_idx = remaining_chunk.find(start_marker) + len(start_marker)
+
+                # This is content that should be treated as reasoning but didn't have a
+                # start tag
+                # if end_idx > 0:
+                reasoning_part = remaining_chunk[start_idx:end_idx]
+                # Clear regular content
+                reasoning_part = msg.content + reasoning_part
+                msg.content = ""
+
+                # First, emit a reasoning_start event
+                events.append(StreamingEvent(type="reasoning_start", content=""))
+
+                # Then emit the content as reasoning content
+                events.append(
+                    StreamingEvent(type="reasoning_content", content=reasoning_part)
+                )
+
+                # Add to reasoning content
+                msg.reasoning_content += reasoning_part
+
+                # Emit the reasoning_end event
+                events.append(StreamingEvent(type="reasoning_end", content=""))
+                # Move past the end marker
+                remaining_chunk = remaining_chunk[end_idx + len(end_marker) :]
+                found_end_marker = True
+                state["reasoning_pattern"] = None
+                break
+
+        # If we found an end marker, continue to the next iteration
+        if found_end_marker:
+            continue
+
         # Check for reasoning start markers
         reasoning_start_found = False
         for pattern in reasoning_patterns:
@@ -228,10 +271,10 @@ def process_streaming_chunk(
                 start_idx = remaining_chunk.find(start_marker)

                 # Output regular content before the marker
-                if start_idx > 0:
+                # if start_idx > 0:
                 content_part = remaining_chunk[:start_idx]
                 events.append(StreamingEvent(type="content", content=content_part))
                 msg.content += content_part

                 # Output reasoning start event
                 events.append(StreamingEvent(type="reasoning_start", content=""))
@@ -257,12 +300,10 @@ def process_streaming_chunk(
                 start_idx = remaining_chunk.find(start_marker)

                 # Output regular content before the marker
-                if start_idx > 0:
-                    content_part = remaining_chunk[:start_idx]
-                    events.append(
-                        StreamingEvent(type="content", content=content_part)
-                    )
-                    msg.content += content_part
+                # if start_idx > 0:
+                content_part = remaining_chunk[:start_idx]
+                events.append(StreamingEvent(type="content", content=content_part))
+                msg.content += content_part

                 # Output tool call start event
                 events.append(StreamingEvent(type="tool_call_start", content=""))
@@ -355,6 +396,7 @@ def parse_chat_message(
     reasoning_content = ""
     content = input_text

+    # First check for the normal case with proper start and end markers
     for pattern in reasoning_patterns:
         start_marker = pattern["start"]
         end_marker = pattern["end"]
@@ -372,6 +414,44 @@ def parse_chat_message(
                 content = content[:start_idx] + content[end_idx + len(end_marker) :]
             break

+    # If no reasoning content was found with the standard pattern, check for the
+    # special case
+    # where content starts with reasoning but has no start marker
+    if not reasoning_content:
+        for pattern in reasoning_patterns:
+            start_marker = pattern["start"]
+            end_marker = pattern["end"]
+
+            if end_marker in content:
+                # Check if this is at the beginning of the content or
+                # if there's no matching start marker before it
+                end_idx = content.find(end_marker)
+                start_marker = pattern["start"]
+                start_idx = content.find(start_marker)
+
+                # If no start marker or end marker appears before start marker
+                if start_idx == -1 or end_idx < start_idx:
+                    # This is our special case - treat the content up to the end marker
+                    # as reasoning
+                    reasoning_content = string_strip(content[:end_idx])

+                    # Remove reasoning part from original content
+                    if extract_reasoning:
+                        content = content[end_idx + len(end_marker) :]
+                    break
+            elif start_marker in content:
+                # If there's a start marker but no end marker, treat the content
+                # as reasoning content
+                start_idx = content.find(start_marker)
+                reasoning_content = string_strip(
+                    content[start_idx + len(start_marker) :]
+                )
+
+                # Remove reasoning part from original content
+                if extract_reasoning:
+                    content = ""
+                break
+
     # Parse tool calls
     tool_calls_text = ""
@@ -1,5 +1,3 @@
-import pytest
-
 from ..parse_utils import parse_chat_message


@@ -357,8 +355,8 @@ def test_streaming_mode_without_tool_calls():

     # Verify final message should include tool call markers
     assert (
-        "I will search for data.<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>Search complete."  # noqa
-        == msg.content
+        "I will search for data.<|tool▁calls▁begin|>Tool call content"
+        "<|tool▁calls▁end|>Search complete." == msg.content  # noqa
     )
     assert len(msg.tool_calls) == 0

@@ -449,48 +447,6 @@ def test_incomplete_markers():
     assert "Tool content" in msg2.streaming_state.get("tool_call_text", "")


-def test_multiple_special_sections():
-    """Test handling multiple special sections"""
-    input_text = """<think>Reasoning content 1</think>Regular content 1
-<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>
-Regular content 2<think>Reasoning content 2</think>End"""
-
-    result = parse_chat_message(input_text, extract_tool_calls=True)
-
-    # Verify only first reasoning content is extracted
-    assert "Reasoning content 1" == result.reasoning_content
-    assert (
-        "Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
-        == result.content
-    )
-
-    # Use streaming processing to handle multiple reasoning parts
-    chunks = [
-        "<think>Reasoning content 1</think>Regular content 1\n",
-        "<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>\n",
-        "Regular content 2<think>Reasoning content 2</think>End",
-    ]
-
-    msg = None
-    all_events = []
-
-    for chunk in chunks:
-        msg, events = parse_chat_message(
-            chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
-        )
-        all_events.extend(events)
-
-    # Verify event sequence contains two reasoning sections
-    reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
-    reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
-
-    assert reasoning_start_counts == 2
-    assert reasoning_end_counts == 2
-
-    # In streaming mode, reasoning content should accumulate
-    assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content
-
-
 def test_custom_streaming_patterns():
     """Test custom streaming pattern markers"""
     custom_reasoning = [{"start": "{{thinking}}", "end": "{{/thinking}}"}]
@@ -530,41 +486,239 @@ def test_custom_streaming_patterns():
     assert "tool_call_end" in event_types


-def test_alternating_reasoning_and_tool_calls():
-    """Test alternating between reasoning and tool calls in a single message"""
-    # Use streaming to capture all sections
-    chunks = [
-        "<think>First reasoning block</think>Content 1\n",
-        "<|tool▁calls▁begin|>Tool call 1<|tool▁calls▁end|>\n",
-        "Content 2<think>Second reasoning block</think>\n",
-        "<|tool▁calls▁begin|>Tool call 2<|tool▁calls▁end|>\n",
-        "Final content",
-    ]
-
-    msg = None
-    all_events = []
-
-    for chunk in chunks:
-        msg, events = parse_chat_message(
-            chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
-        )
-        all_events.extend(events)
-
-    # Verify content is parsed correctly - note the double newlines
-    assert "Content 1\n\nContent 2\n\nFinal content" == msg.content
-    assert "First reasoning blockSecond reasoning block" == msg.reasoning_content
-
-    # Count events by type
-    event_counts = {}
-    for e in all_events:
-        event_counts[e.type] = event_counts.get(e.type, 0) + 1
-
-    assert event_counts.get("reasoning_start", 0) == 2
-    assert event_counts.get("reasoning_end", 0) == 2
-    assert event_counts.get("tool_call_start", 0) == 2
-    assert event_counts.get("tool_call_end", 0) == 2
-
-
-if __name__ == "__main__":
-    # Run tests
-    pytest.main(["-v", "test_parse_utils.py"])
+def test_missing_start_token_non_streaming():
+    """Test parsing messages with missing start token but having end token
+    (non-streaming mode)
+    """
+    input_text = """Model reasoning content without start token.
+</think>
+This is the regular content part."""
+
+    result = parse_chat_message(input_text, extract_reasoning=True)
+
+    assert "This is the regular content part." == result.content
+    assert "Model reasoning content without start token." == result.reasoning_content
+
+
+def test_missing_start_token_streaming():
+    """Test parsing messages with missing start token but having end token
+    (streaming mode)
+    """
+    chunks = [
+        "Model reasoning content ",
+        "without start token.</think>",
+        "This is the regular content part.",
+    ]
+
+    msg = None
+    all_events = []
+
+    for chunk in chunks:
+        msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
+        all_events.extend(events)
+
+    # Verify final message - match the expected result in test
+    assert "This is the regular content part." == msg.content
+    assert "Model reasoning content without start token." == msg.reasoning_content
+
+    # Verify event sequence contains correct reasoning events
+    event_types = [e.type for e in all_events]
+    assert "reasoning_start" in event_types
+    assert "reasoning_content" in event_types
+    assert "reasoning_end" in event_types
+
+
+def test_missing_start_token_deepseek_chinese():
+    """Test the DeepSeek example with Chinese content missing start token"""
+    input_text = """您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手\
+DeepSeek-R1。有关模型和产品的详细内容请参考官方文档。
+</think>
+您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。有关模型\
+和产品的详细内容请参考官方文档。"""
+
+    result = parse_chat_message(input_text, extract_reasoning=True)
+
+    assert (
+        "您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。"
+        "有关模型和产品的详细内容请参考官方文档。" == result.content
+    )
+    assert (
+        "您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。"
+        "有关模型和产品的详细内容请参考官方文档。" == result.reasoning_content
+    )
+
+
+def test_multiple_missing_start_tokens():
+    """Test multiple occurrences of missing start tokens in the same message"""
+    input_text = """First reasoning section.
+</think>
+Some regular content.
+Second reasoning section.
+</reasoning>
+More regular content."""
+
+    result = parse_chat_message(input_text, extract_reasoning=True)
+
+    # Note: In non-streaming mode, only the first matching reasoning content is
+    # extracted
+    assert (
+        "Some regular content.\nSecond reasoning section.\n</reasoning>\nMore regular "
+        "content." == result.content
+    )
+    assert "First reasoning section." == result.reasoning_content
+
+    # Use streaming to capture all sections
+    chunks = [
+        "First reasoning section.\n</think>\n",
+        "Some regular content.\n",
+        "Second reasoning section.\n</reasoning>\n",
+        "More regular content.",
+    ]
+
+    msg = None
+    all_events = []
+
+    for chunk in chunks:
+        msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
+        all_events.extend(events)
+
+    # In streaming mode, reasoning content should match the expected format
+    assert (
+        "First reasoning section.\n\nSome regular content.\nSecond reasoning section.\n"
+        == msg.reasoning_content
+    )
+
+    # Verify event sequence contains two reasoning sections
+    reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
+    reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
+
+    assert reasoning_start_counts == 2
+    assert reasoning_end_counts == 2
+
+
+def test_missing_start_token_with_tools():
+    """Test missing reasoning start token with tool calls"""
+    input_text = """Analyzing user request to query weather information.
+</think>
+I'll look up the weather data for you.
+
+<|tool▁calls▁begin|>
+<|tool▁call▁begin|>function<|tool▁sep|>get_weather
+```json
+{
+    "location": "Beijing",
+    "date": "2023-05-20"
+}
+```
+<|tool▁call▁end|>
+<|tool▁calls▁end|>"""
+
+    result = parse_chat_message(
+        input_text, extract_reasoning=True, extract_tool_calls=True
+    )
+
+    assert "I'll look up the weather data for you." in result.content
+    assert (
+        "Analyzing user request to query weather information."
+        == result.reasoning_content
+    )
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0]["name"] == "get_weather"
+    assert result.tool_calls[0]["arguments"]["location"] == "Beijing"
+
+
+def test_mixed_language_missing_start_token():
+    """Test mixed Chinese and English content with missing start token"""
+    input_text = """这是一段中英文混合的思考内容 with both languages mixed together.
+</think>
+Here's the regular content with 中文 mixed in."""
+
+    result = parse_chat_message(input_text, extract_reasoning=True)
+
+    assert "Here's the regular content with 中文 mixed in." == result.content
+    assert (
+        "这是一段中英文混合的思考内容 with both languages mixed together."
+        == result.reasoning_content
+    )
+
+
+def test_streaming_mixed_language_missing_start():
+    """Test streaming mixed language content with missing start token"""
+    chunks = [
+        "Analysis 分析: The user needs ",
+        "information about 关于天气的信息。</reasoning>",
+        "I'll provide weather information 我将提供天气信息。",
+    ]
+
+    msg = None
+    all_events = []
+
+    for chunk in chunks:
+        msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
+        all_events.extend(events)
+
+    # Verify final message includes mixed language content properly parsed
+    assert "I'll provide weather information 我将提供天气信息。" == msg.content
+    assert (
+        "Analysis 分析: The user needs information about 关于天气的信息。"
+        == msg.reasoning_content
+    )
+
+    # Verify events sequence
+    reasoning_events = [e for e in all_events if e.type.startswith("reasoning_")]
+    assert len(reasoning_events) >= 3  # At least start, content, and end events
+
+
+def test_chinese_pattern_missing_start():
+    """Test Chinese pattern with missing start token"""
+    input_text = """这里是模型的思考内容,但是没有开始标记。
+</思考>
+这是正常的响应内容。"""
+
+    result = parse_chat_message(input_text, extract_reasoning=True)
+
+    assert "这是正常的响应内容。" == result.content
+    assert "这里是模型的思考内容,但是没有开始标记。" == result.reasoning_content
+
+
+#
+# def test_multiple_special_sections():
+#     """Test handling multiple special sections"""
+#     input_text = """<think>Reasoning content 1</think>Regular content 1
+# <|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>
+# Regular content 2<think>Reasoning content 2</think>End"""
+#
+#     result = parse_chat_message(input_text, extract_tool_calls=True)
+#
+#     # Verify only first reasoning content is extracted
+#     assert "Reasoning content 1" == result.reasoning_content
+#     assert (
+#         "Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
+#         == result.content
+#     )
+#
+#     # Use streaming processing to handle multiple reasoning parts
+#     chunks = [
+#         "<think>Reasoning content 1</think>Regular content 1\n",
+#         "<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>\n",
+#         "Regular content 2<think>Reasoning content 2</think>End",
+#     ]
+#
+#     msg = None
+#     all_events = []
+#
+#     for chunk in chunks:
+#         msg, events = parse_chat_message(
+#             chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
+#         )
+#         all_events.extend(events)
+#
+#     # Verify event sequence contains two reasoning sections
+#     reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
+#     reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
+#
+#     assert reasoning_start_counts == 2
+#     assert reasoning_end_counts == 2
+#
+#     # In streaming mode, reasoning content should match the expected format
+#     assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content
@@ -113,7 +113,7 @@ class HFEmbeddingDeployModelParameters(EmbeddingDeployModelParameters):
     @property
     def real_provider_model_name(self) -> str:
         """Get the real provider model name."""
-        return self.path or self.name
+        return self.real_model_path or self.name

     @property
     def real_model_path(self) -> Optional[str]:
@@ -40,8 +40,7 @@ class ClickhouseParameters(BaseDatasourceParameters):
     user: str = field(metadata={"help": _("Database user to connect")})
     database: str = field(metadata={"help": _("Database name")})
     engine: str = field(
-        default="MergeTree",
-        metadata={"help": _("Storage engine, e.g., MergeTree")}
+        default="MergeTree", metadata={"help": _("Storage engine, e.g., MergeTree")}
     )
     password: str = field(
         default="${env:DBGPT_DB_PASSWORD}",
@@ -13,16 +13,22 @@ readme = "README.md"
 requires-python = ">= 3.10"

 [tool.uv.sources]
-dbgpt-accelerator = { workspace = true }
 dbgpt = { workspace = true }
 dbgpt-client = { workspace = true }
 dbgpt-ext = { workspace = true }
 dbgpt-serve = { workspace = true }
 dbgpt-app = { workspace = true }
+dbgpt-acc-auto = { workspace = true }
+dbgpt-acc-flash-attn = { workspace = true }

 [tool.uv.workspace]
 members = [
-    "packages/dbgpt-*"
+    "packages/dbgpt-app",
+    "packages/dbgpt-client",
+    "packages/dbgpt-core",
+    "packages/dbgpt-ext",
+    "packages/dbgpt-serve",
+    "packages/dbgpt-accelerator/*"
 ]

 [tool.uv]
@@ -65,4 +71,4 @@ select = ["E", "F", "I"]

 [tool.ruff.isort]
 # Specify the local modules (first-party)
-known-first-party = ["dbgpt", "dbgpt_accelerator", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]
+known-first-party = ["dbgpt", "dbgpt_acc_auto", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]