mistral[patch]: translate tool call IDs to mistral compatible format (#24668)

Mistral appears to have added validation for the format of its tool call
IDs:

`{"object":"error","message":"Tool call id was abc123 but must be a-z,
A-Z, 0-9, with a length of
9.","type":"invalid_request_error","param":null,"code":null}`

This breaks compatibility of messages from other providers. Here we add
a function that converts any string to a Mistral-valid tool call ID, and
apply it to incoming messages.
This commit is contained in:
ccurme
2024-07-25 12:39:32 -04:00
committed by GitHub
parent 38d30e285a
commit dfbd12b384
2 changed files with 63 additions and 8 deletions

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import hashlib
import json
import logging
import re
import uuid
from operator import itemgetter
from typing import (
@@ -77,6 +79,9 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
logger = logging.getLogger(__name__)
# Mistral enforces a specific pattern for tool call IDs
TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")
def _create_retry_decorator(
llm: ChatMistralAI,
@@ -92,6 +97,39 @@ def _create_retry_decorator(
)
def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
"""Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9"""
return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))
def _base62_encode(num: int) -> str:
"""Encodes a number in base62 and ensures result is of a specified length."""
base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
if num == 0:
return base62[0]
arr = []
base = len(base62)
while num:
num, rem = divmod(num, base)
arr.append(base62[rem])
arr.reverse()
return "".join(arr)
def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
"""Convert a tool call ID to a Mistral-compatible format"""
if _is_valid_mistral_tool_call_id(tool_call_id):
return tool_call_id
else:
hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
hash_int = int.from_bytes(hash_bytes, byteorder="big")
base62_str = _base62_encode(hash_int)
if len(base62_str) >= 9:
return base62_str[:9]
else:
return base62_str.rjust(9, "0")
def _convert_mistral_chat_message_to_message(
_message: Dict,
) -> BaseMessage:
@@ -246,7 +284,7 @@ def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
}
}
if _id := tool_call.get("id"):
result["id"] = _id
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
return result
@@ -260,7 +298,7 @@ def _format_invalid_tool_call_for_mistral(invalid_tool_call: InvalidToolCall) ->
}
}
if _id := invalid_tool_call.get("id"):
result["id"] = _id
result["id"] = _convert_tool_call_id_to_mistral_compatible(_id)
return result