mistral[patch]: translate tool call IDs to mistral compatible format (#24668)

Mistral appears to have added validation for the format of its tool call
IDs:

`{"object":"error","message":"Tool call id was abc123 but must be a-z,
A-Z, 0-9, with a length of
9.","type":"invalid_request_error","param":null,"code":null}`

This breaks compatibility of messages from other providers. Here we add
a function that converts any string to a Mistral-valid tool call ID, and
apply it to incoming messages.
This commit is contained in:
ccurme
2024-07-25 12:39:32 -04:00
committed by GitHub
parent 38d30e285a
commit dfbd12b384
2 changed files with 63 additions and 8 deletions

View File

@@ -21,6 +21,8 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
ChatMistralAI,
_convert_message_to_mistral_chat_message,
_convert_mistral_chat_message_to_message,
_convert_tool_call_id_to_mistral_compatible,
_is_valid_mistral_tool_call_id,
)
os.environ["MISTRAL_API_KEY"] = "foo"
@@ -128,7 +130,7 @@ async def test_astream_with_callback() -> None:
def test__convert_dict_to_message_tool_call() -> None:
raw_tool_call = {
"id": "abc123",
"id": "ssAbar4Dr",
"function": {
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
@@ -143,7 +145,7 @@ def test__convert_dict_to_message_tool_call() -> None:
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="abc123",
id="ssAbar4Dr",
type="tool_call",
)
],
@@ -154,14 +156,14 @@ def test__convert_dict_to_message_tool_call() -> None:
# Test malformed tool call
raw_tool_calls = [
{
"id": "def456",
"id": "pL5rEGzxe",
"function": {
"arguments": '{"name": "Sally", "hair_color": "green"}',
"name": "GenerateUsername",
},
},
{
"id": "abc123",
"id": "ssAbar4Dr",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
@@ -178,7 +180,7 @@ def test__convert_dict_to_message_tool_call() -> None:
name="GenerateUsername",
args="oops",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
id="abc123",
id="ssAbar4Dr",
type="invalid_tool_call",
),
],
@@ -186,7 +188,7 @@ def test__convert_dict_to_message_tool_call() -> None:
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="def456",
id="pL5rEGzxe",
type="tool_call",
),
],
@@ -201,3 +203,18 @@ def test_custom_token_counting() -> None:
llm = ChatMistralAI(custom_get_token_ids=token_encoder)
assert llm.get_token_ids("foo") == [1, 2, 3]
def test_tool_id_conversion() -> None:
assert _is_valid_mistral_tool_call_id("ssAbar4Dr")
assert not _is_valid_mistral_tool_call_id("abc123")
assert not _is_valid_mistral_tool_call_id("call_JIIjI55tTipFFzpcP8re3BpM")
result_map = {
"ssAbar4Dr": "ssAbar4Dr",
"abc123": "pL5rEGzxe",
"call_JIIjI55tTipFFzpcP8re3BpM": "8kxAQvoED",
}
for input_id, expected_output in result_map.items():
assert _convert_tool_call_id_to_mistral_compatible(input_id) == expected_output
assert _is_valid_mistral_tool_call_id(expected_output)