mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
huggingface: fix tool argument serialization in _convert_TGI_message_to_LC_message (#26075)
Currently `_convert_TGI_message_to_LC_message` replaces `'` in the tool arguments, so an argument like "It's" will be converted to `It"s` and could cause a json parser to fail. --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
parent
5a31792bf1
commit
d834c6b618
@ -1,5 +1,6 @@
|
|||||||
"""Hugging Face Chat Wrapper."""
|
"""Hugging Face Chat Wrapper."""
|
||||||
|
|
||||||
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -106,9 +107,10 @@ def _convert_TGI_message_to_LC_message(
|
|||||||
additional_kwargs: Dict = {}
|
additional_kwargs: Dict = {}
|
||||||
if tool_calls := _message.tool_calls:
|
if tool_calls := _message.tool_calls:
|
||||||
if "arguments" in tool_calls[0]["function"]:
|
if "arguments" in tool_calls[0]["function"]:
|
||||||
functions_string = str(tool_calls[0]["function"].pop("arguments"))
|
functions = tool_calls[0]["function"].pop("arguments")
|
||||||
corrected_functions = functions_string.replace("'", '"')
|
tool_calls[0]["function"]["arguments"] = json.dumps(
|
||||||
tool_calls[0]["function"]["arguments"] = corrected_functions
|
functions, ensure_ascii=False
|
||||||
|
)
|
||||||
additional_kwargs["tool_calls"] = tool_calls
|
additional_kwargs["tool_calls"] = tool_calls
|
||||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ def test_convert_message_to_chat_message(
|
|||||||
TGI_MESSAGE(
|
TGI_MESSAGE(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content="",
|
content="",
|
||||||
tool_calls=[{"function": {"arguments": "'function string'"}}],
|
tool_calls=[{"function": {"arguments": "function string"}}],
|
||||||
),
|
),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="",
|
content="",
|
||||||
@ -75,6 +75,23 @@ def test_convert_message_to_chat_message(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
TGI_MESSAGE(
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{"function": {"arguments": {"answer": "function's string"}}}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{"function": {"arguments": '{"answer": "function\'s string"}'}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_convert_TGI_message_to_LC_message(
|
def test_convert_TGI_message_to_LC_message(
|
||||||
|
Loading…
Reference in New Issue
Block a user