mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
fix: LLM mimicking Unicode responses due to forced Unicode conversion of non-ASCII characters. (#32222)
fix: Fix LLM mimicking escaped Unicode output caused by forced escaping of non-ASCII characters. - **Description:** This PR fixes an issue where the LLM would mimic `\uXXXX` escape sequences in its responses because non-ASCII characters in tool calls were forcibly escaped during serialization. The fix disables the `ensure_ascii` flag in `json.dumps()` when converting tool calls to OpenAI format. - **Issue:** Fixes the behavior shown below ↓↓↓ Input: ```json {'role': 'assistant', 'tool_calls': [{'type': 'function', 'id': 'call_nv9trcehdpihr21zj9po19vq', 'function': {'name': 'create_customer', 'arguments': '{"customer_name": "你好啊集团"}'}}]} ``` Output: ```json {'role': 'assistant', 'tool_calls': [{'type': 'function', 'id': 'call_nv9trcehdpihr21zj9po19vq', 'function': {'name': 'create_customer', 'arguments': '{"customer_name": "\\u4f60\\u597d\\u554a\\u96c6\\u56e2"}'}}]} ``` As a result, the LLM mimics the escaped form and starts emitting `\uXXXX` sequences itself. The escaped form is much longer than the original characters, which lengthens LLM responses and slows them down. <img width="686" height="277" alt="image" src="https://github.com/user-attachments/assets/28f3b007-3964-4455-bee2-68f86ac1906d" /> --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
d53ebf367e
commit
0d6f915442
@ -56,7 +56,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
1418
libs/cli/uv.lock
1418
libs/cli/uv.lock
File diff suppressed because it is too large
Load Diff
@ -1176,7 +1176,9 @@ def convert_to_openai_messages(
|
|||||||
"id": block["id"],
|
"id": block["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": block["name"],
|
"name": block["name"],
|
||||||
"arguments": json.dumps(block["input"]),
|
"arguments": json.dumps(
|
||||||
|
block["input"], ensure_ascii=False
|
||||||
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -1550,7 +1552,7 @@ def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
|
|||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tool_call in tool_calls
|
for tool_call in tool_calls
|
||||||
|
@ -1121,6 +1121,33 @@ def test_convert_to_openai_messages_tool_use() -> None:
|
|||||||
assert result[0]["tool_calls"][0]["function"]["arguments"] == json.dumps({"a": "b"})
|
assert result[0]["tool_calls"][0]["function"]["arguments"] == json.dumps({"a": "b"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_openai_messages_tool_use_unicode() -> None:
    """Check that non-ASCII tool-call arguments survive conversion unescaped."""
    customer = "你好啊集团"
    tool_use_block = {
        "type": "tool_use",
        "id": "123",
        "name": "create_customer",
        "input": {"customer_name": customer},
    }

    converted = convert_to_openai_messages(
        [AIMessage(content=[tool_use_block])], text_format="block"
    )

    call = converted[0]["tool_calls"][0]
    assert call["type"] == "function"
    assert call["id"] == "123"
    assert call["function"]["name"] == "create_customer"

    # The serialized arguments must round-trip to the original value ...
    raw_arguments = call["function"]["arguments"]
    assert json.loads(raw_arguments)["customer_name"] == customer
    # ... and must carry the literal characters, not backslash-u escapes.
    assert customer in raw_arguments
    assert "\\u4f60" not in raw_arguments
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_openai_messages_json() -> None:
|
def test_convert_to_openai_messages_json() -> None:
|
||||||
json_data = {"key": "value"}
|
json_data = {"key": "value"}
|
||||||
messages = [HumanMessage(content=[{"type": "json", "json": json_data}])]
|
messages = [HumanMessage(content=[{"type": "json", "json": json_data}])]
|
||||||
|
@ -67,7 +67,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -69,7 +69,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -55,7 +55,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -98,7 +98,7 @@ print("Similar Results:", similar_results)
|
|||||||
All Exa tools support the following common parameters:
|
All Exa tools support the following common parameters:
|
||||||
|
|
||||||
- `num_results` (1-100): Number of search results to return
|
- `num_results` (1-100): Number of search results to return
|
||||||
- `type`: Search type - "neural", "keyword", or "auto"
|
- `type`: Search type - "neural", "keyword", or "auto"
|
||||||
- `livecrawl`: Live crawling mode - "always", "fallback", or "never"
|
- `livecrawl`: Live crawling mode - "always", "fallback", or "never"
|
||||||
- `summary`: Get AI-generated summaries (True/False or custom prompt dict)
|
- `summary`: Get AI-generated summaries (True/False or custom prompt dict)
|
||||||
- `text_contents_options`: Dict to limit text length (e.g. `{"max_characters": 2000}`)
|
- `text_contents_options`: Dict to limit text length (e.g. `{"max_characters": 2000}`)
|
||||||
|
@ -54,7 +54,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -1070,7 +1070,7 @@ def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
|
|||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +58,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -653,7 +653,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.3.68"
|
version = "0.3.72"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
@ -669,7 +669,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
||||||
{ name = "langsmith", specifier = ">=0.3.45" },
|
{ name = "langsmith", specifier = ">=0.3.45" },
|
||||||
{ name = "packaging", specifier = ">=23.2,<25" },
|
{ name = "packaging", specifier = ">=23.2" },
|
||||||
{ name = "pydantic", specifier = ">=2.7.4" },
|
{ name = "pydantic", specifier = ">=2.7.4" },
|
||||||
{ name = "pyyaml", specifier = ">=5.3" },
|
{ name = "pyyaml", specifier = ">=5.3" },
|
||||||
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
||||||
|
@ -1339,7 +1339,7 @@ def _lc_tool_call_to_groq_tool_call(tool_call: ToolCall) -> dict:
|
|||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,7 +51,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -331,7 +331,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.3.68"
|
version = "0.3.72"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
@ -347,7 +347,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
||||||
{ name = "langsmith", specifier = ">=0.3.45" },
|
{ name = "langsmith", specifier = ">=0.3.45" },
|
||||||
{ name = "packaging", specifier = ">=23.2,<25" },
|
{ name = "packaging", specifier = ">=23.2" },
|
||||||
{ name = "pydantic", specifier = ">=2.7.4" },
|
{ name = "pydantic", specifier = ">=2.7.4" },
|
||||||
{ name = "pyyaml", specifier = ">=5.3" },
|
{ name = "pyyaml", specifier = ">=5.3" },
|
||||||
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
||||||
|
@ -88,7 +88,7 @@ def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:
|
|||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +65,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -296,7 +296,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
|
|||||||
result: dict[str, Any] = {
|
result: dict[str, Any] = {
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _id := tool_call.get("id"):
|
if _id := tool_call.get("id"):
|
||||||
|
@ -55,7 +55,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -350,7 +350,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.3.68"
|
version = "0.3.72"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
@ -366,7 +366,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
{ name = "jsonpatch", specifier = ">=1.33,<2.0" },
|
||||||
{ name = "langsmith", specifier = ">=0.3.45" },
|
{ name = "langsmith", specifier = ">=0.3.45" },
|
||||||
{ name = "packaging", specifier = ">=23.2,<25" },
|
{ name = "packaging", specifier = ">=23.2" },
|
||||||
{ name = "pydantic", specifier = ">=2.7.4" },
|
{ name = "pydantic", specifier = ">=2.7.4" },
|
||||||
{ name = "pyyaml", specifier = ">=5.3" },
|
{ name = "pyyaml", specifier = ">=5.3" },
|
||||||
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
{ name = "tenacity", specifier = ">=8.1.0,!=8.4.0,<10.0.0" },
|
||||||
|
@ -51,7 +51,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -54,7 +54,6 @@ select = [
|
|||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
"COM", # flake8-commas
|
"COM", # flake8-commas
|
||||||
"D", # pydocstyle
|
"D", # pydocstyle
|
||||||
"DOC", # pydoclint
|
|
||||||
"E", # pycodestyle error
|
"E", # pycodestyle error
|
||||||
"EM", # flake8-errmsg
|
"EM", # flake8-errmsg
|
||||||
"F", # pyflakes
|
"F", # pyflakes
|
||||||
|
@ -211,7 +211,7 @@ def _convert_from_v03_ai_message(message: AIMessage) -> AIMessage:
|
|||||||
function_call = {
|
function_call = {
|
||||||
"type": "function_call",
|
"type": "function_call",
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
"call_id": tool_call["id"],
|
"call_id": tool_call["id"],
|
||||||
}
|
}
|
||||||
if function_call_ids is not None and (
|
if function_call_ids is not None and (
|
||||||
|
@ -3178,7 +3178,7 @@ def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
|||||||
"id": tool_call["id"],
|
"id": tool_call["id"],
|
||||||
"function": {
|
"function": {
|
||||||
"name": tool_call["name"],
|
"name": tool_call["name"],
|
||||||
"arguments": json.dumps(tool_call["args"]),
|
"arguments": json.dumps(tool_call["args"], ensure_ascii=False),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2530,6 +2530,32 @@ def test_make_computer_call_output_from_message() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_lc_tool_call_to_openai_tool_call_unicode() -> None:
    """Check that non-ASCII tool-call args are serialized without escaping."""
    from langchain_openai.chat_models.base import _lc_tool_call_to_openai_tool_call

    customer = "你好啊集团"
    converted = _lc_tool_call_to_openai_tool_call(
        ToolCall(
            id="call_123",
            name="create_customer",
            args={"customer_name": customer},
            type="tool_call",
        )
    )

    assert converted["type"] == "function"
    assert converted["id"] == "call_123"
    function_payload = converted["function"]
    assert function_payload["name"] == "create_customer"

    # The serialized arguments must round-trip to the original value ...
    raw_arguments = function_payload["arguments"]
    assert json.loads(raw_arguments)["customer_name"] == customer
    # ... and must carry the literal characters, not backslash-u escapes.
    assert customer in raw_arguments
    assert "\\u4f60" not in raw_arguments
|
||||||
|
|
||||||
|
|
||||||
def test_extra_body_parameter() -> None:
|
def test_extra_body_parameter() -> None:
|
||||||
"""Test that extra_body parameter is properly included in request payload."""
|
"""Test that extra_body parameter is properly included in request payload."""
|
||||||
llm = ChatOpenAI(
|
llm = ChatOpenAI(
|
||||||
|
@ -480,7 +480,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.3.70"
|
version = "0.3.72"
|
||||||
source = { editable = "../../core" }
|
source = { editable = "../../core" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "jsonpatch" },
|
{ name = "jsonpatch" },
|
||||||
|
@ -127,6 +127,20 @@ def _validate_tool_call_message_no_args(message: BaseMessage) -> None:
|
|||||||
assert tool_call.get("type") == "tool_call"
|
assert tool_call.get("type") == "tool_call"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
def unicode_customer(customer_name: str, description: str) -> str:
    """Tool for creating a customer with Unicode name.

    Args:
        customer_name: The customer's name in their native language.
        description: Description of the customer.

    Returns:
        A confirmation message about the customer creation.
    """
    # NOTE(review): the docstring above is presumably surfaced to the model as
    # the tool description by ``@tool``, so it is kept verbatim - confirm
    # before rewording it.
    confirmation = f"Created customer: {customer_name} - {description}"
    return confirmation
|
||||||
|
|
||||||
|
|
||||||
class ChatModelIntegrationTests(ChatModelTests):
|
class ChatModelIntegrationTests(ChatModelTests):
|
||||||
"""Base class for chat model integration tests.
|
"""Base class for chat model integration tests.
|
||||||
|
|
||||||
@ -2900,3 +2914,95 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
|
def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
|
||||||
""":private:"""
|
""":private:"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_unicode_tool_call_integration(
    self,
    model: BaseChatModel,
    *,
    tool_choice: Optional[str] = None,
    force_tool_call: bool = True,
) -> None:
    """Generic integration test for Unicode characters in tool calls.

    Binds the ``unicode_customer`` tool, asks the model to create a customer
    with a Chinese name and then one with a Japanese name, and checks that the
    ``customer_name`` argument of the first returned tool call contains the
    native characters, not escaped \\uXXXX sequences.

    Args:
        model: The chat model to test.
        tool_choice: Tool choice parameter to pass to ``bind_tools``
            (provider-specific). If ``None`` and ``force_tool_call`` is
            ``False``, ``bind_tools`` is called without a ``tool_choice``.
        force_tool_call: Whether a tool call is required. If ``True`` and
            ``tool_choice`` is ``None``, ``tool_choice="any"`` is used, and
            the test fails when the model returns no tool call.
    """
    if not self.has_tool_calling:
        pytest.skip("Test requires tool calling support.")

    # Configure tool choice based on provider capabilities
    if tool_choice is None and force_tool_call:
        tool_choice = "any"

    if tool_choice is not None:
        llm_with_tool = model.bind_tools([unicode_customer], tool_choice=tool_choice)
    else:
        llm_with_tool = model.bind_tools([unicode_customer])

    # Test with Chinese characters
    msgs = [
        HumanMessage(
            "Create a customer named '你好啊集团' (Hello Group) - a Chinese "
            "technology company"
        )
    ]
    ai_msg = llm_with_tool.invoke(msgs)

    assert isinstance(ai_msg, AIMessage)
    assert isinstance(ai_msg.tool_calls, list)

    if force_tool_call:
        assert len(ai_msg.tool_calls) >= 1, (
            f"Expected at least 1 tool call, got {len(ai_msg.tool_calls)}"
        )

    # Without a forced tool call the model may answer in plain text; the
    # argument checks only apply when a tool call was actually produced.
    if ai_msg.tool_calls:
        tool_call = ai_msg.tool_calls[0]
        assert tool_call["name"] == "unicode_customer"
        assert "args" in tool_call

        # Verify Unicode characters are properly handled
        args = tool_call["args"]
        assert "customer_name" in args
        customer_name = args["customer_name"]

        # Lenient on purpose: the model may shorten or rephrase the name, so
        # any fragment of the original characters is accepted.
        chinese_fragments = ("你好", "你", "好")
        assert any(fragment in customer_name for fragment in chinese_fragments), (
            f"Unicode characters not found in: {customer_name}"
        )

    # Test with additional Unicode examples - Japanese
    msgs_jp = [
        HumanMessage(
            "Create a customer named 'こんにちは株式会社' (Hello Corporation) - a "
            "Japanese company"
        )
    ]
    ai_msg_jp = llm_with_tool.invoke(msgs_jp)

    assert isinstance(ai_msg_jp, AIMessage)

    if force_tool_call:
        assert len(ai_msg_jp.tool_calls) >= 1, (
            f"Expected at least 1 tool call, got {len(ai_msg_jp.tool_calls)}"
        )

    if ai_msg_jp.tool_calls:
        tool_call_jp = ai_msg_jp.tool_calls[0]
        args_jp = tool_call_jp["args"]
        customer_name_jp = args_jp["customer_name"]

        # Verify Japanese Unicode characters are preserved (same lenient,
        # any-fragment matching as the Chinese case above).
        japanese_fragments = ("こんにちは", "株式会社", "こ", "ん")
        assert any(
            fragment in customer_name_jp for fragment in japanese_fragments
        ), f"Japanese Unicode characters not found in: {customer_name_jp}"
|
||||||
|
@ -68,6 +68,9 @@ class ChatParrotLink(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
# Replace this with actual logic to generate a response from a list
|
# Replace this with actual logic to generate a response from a list
|
||||||
# of messages.
|
# of messages.
|
||||||
|
_ = stop # Mark as used to avoid unused variable warning
|
||||||
|
_ = run_manager # Mark as used to avoid unused variable warning
|
||||||
|
_ = kwargs # Mark as used to avoid unused variable warning
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
tokens = last_message.content[: self.parrot_buffer_length]
|
tokens = last_message.content[: self.parrot_buffer_length]
|
||||||
ct_input_tokens = sum(len(message.content) for message in messages)
|
ct_input_tokens = sum(len(message.content) for message in messages)
|
||||||
@ -114,6 +117,8 @@ class ChatParrotLink(BaseChatModel):
|
|||||||
downstream and understand why generation stopped.
|
downstream and understand why generation stopped.
|
||||||
run_manager: A run manager with callbacks for the LLM.
|
run_manager: A run manager with callbacks for the LLM.
|
||||||
"""
|
"""
|
||||||
|
_ = stop # Mark as used to avoid unused variable warning
|
||||||
|
_ = kwargs # Mark as used to avoid unused variable warning
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
tokens = str(last_message.content[: self.parrot_buffer_length])
|
tokens = str(last_message.content[: self.parrot_buffer_length])
|
||||||
ct_input_tokens = sum(len(message.content) for message in messages)
|
ct_input_tokens = sum(len(message.content) for message in messages)
|
||||||
|
@ -1,5 +1,10 @@
|
|||||||
"""Test the standard tests on the custom chat model in the docs."""
|
"""Test the standard tests on the custom chat model in the docs."""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
|
||||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||||
|
|
||||||
@ -24,3 +29,12 @@ class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
|
|||||||
@property
|
@property
|
||||||
def chat_model_params(self) -> dict:
|
def chat_model_params(self) -> dict:
|
||||||
return {"model": "bird-brain-001", "temperature": 0, "parrot_buffer_length": 50}
|
return {"model": "bird-brain-001", "temperature": 0, "parrot_buffer_length": 50}
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="ChatParrotLink doesn't implement bind_tools method")
def test_unicode_tool_call_integration(
    self,
    model: BaseChatModel,
    tool_choice: Optional[str] = None,
    force_tool_call: bool = True,
) -> None:
    """Expected failure as ChatParrotLink doesn't support tool calling yet.

    Args:
        model: The chat model under test.
        tool_choice: Forwarded to the inherited test.
        force_tool_call: Forwarded to the inherited test.
    """
    # Run the inherited test instead of leaving the body empty: an empty body
    # under ``xfail`` always "unexpectedly passes" and checks nothing, whereas
    # delegating lets the outcome track the model's real tool-calling support.
    super().test_unicode_tool_call_integration(
        model, tool_choice=tool_choice, force_tool_call=force_tool_call
    )
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user