huggingface: fix tool argument serialization in _convert_TGI_message_to_LC_message (#26075)

Currently `_convert_TGI_message_to_LC_message` replaces `'` in the tool
arguments, so an argument like "It's" will be converted to `It"s` and
could cause a json parser to fail.

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Vadym Barda <vadym@langchain.dev>
This commit is contained in:
Wang, Yi 2024-12-12 10:34:32 +08:00 committed by GitHub
parent 5a31792bf1
commit d834c6b618
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 4 deletions

View File

@ -1,5 +1,6 @@
"""Hugging Face Chat Wrapper."""
import json
from dataclasses import dataclass
from typing import (
Any,
@ -106,9 +107,10 @@ def _convert_TGI_message_to_LC_message(
additional_kwargs: Dict = {}
if tool_calls := _message.tool_calls:
if "arguments" in tool_calls[0]["function"]:
functions_string = str(tool_calls[0]["function"].pop("arguments"))
corrected_functions = functions_string.replace("'", '"')
tool_calls[0]["function"]["arguments"] = corrected_functions
functions = tool_calls[0]["function"].pop("arguments")
tool_calls[0]["function"]["arguments"] = json.dumps(
functions, ensure_ascii=False
)
additional_kwargs["tool_calls"] = tool_calls
return AIMessage(content=content, additional_kwargs=additional_kwargs)

View File

@ -66,7 +66,7 @@ def test_convert_message_to_chat_message(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[{"function": {"arguments": "'function string'"}}],
tool_calls=[{"function": {"arguments": "function string"}}],
),
AIMessage(
content="",
@ -75,6 +75,23 @@ def test_convert_message_to_chat_message(
},
),
),
(
TGI_MESSAGE(
role="assistant",
content="",
tool_calls=[
{"function": {"arguments": {"answer": "function's string"}}}
],
),
AIMessage(
content="",
additional_kwargs={
"tool_calls": [
{"function": {"arguments": '{"answer": "function\'s string"}'}}
]
},
),
),
],
)
def test_convert_TGI_message_to_LC_message(