Community: Add mistral oss model support to azureml endpoints, plus configurable timeout (#19123)

- **Description:** There was no formatter for mistral models for Azure
ML endpoints. Adding that, plus a configurable timeout (it was hard
coded before)
- **Dependencies:** none
- **Twitter handle:** @tjaffri @docugami
This commit is contained in:
Taqi Jaffri
2024-03-18 21:10:42 -07:00
committed by GitHub
parent 07de4abe70
commit 044bc22acc
2 changed files with 62 additions and 2 deletions

View File

@@ -143,6 +143,44 @@ class LlamaChatContentFormatter(ContentFormatterBase):
raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class MistralChatContentFormatter(LlamaChatContentFormatter):
"""Content formatter for `Mistral`."""
def format_messages_request_payload(
self,
messages: List[BaseMessage],
model_kwargs: Dict,
api_type: AzureMLEndpointApiType,
) -> bytes:
"""Formats the request according to the chosen api"""
chat_messages = [self._convert_message_to_dict(message) for message in messages]
if chat_messages and chat_messages[0]["role"] == "system":
# Mistral OSS models do not explicitly support system prompts, so we have to
# stash in the first user prompt
chat_messages[1]["content"] = (
chat_messages[0]["content"] + "\n\n" + chat_messages[1]["content"]
)
del chat_messages[0]
if api_type == AzureMLEndpointApiType.realtime:
request_payload = json.dumps(
{
"input_data": {
"input_string": chat_messages,
"parameters": model_kwargs,
}
}
)
elif api_type == AzureMLEndpointApiType.serverless:
request_payload = json.dumps({"messages": chat_messages, **model_kwargs})
else:
raise ValueError(
f"`api_type` {api_type} is not supported by this formatter"
)
return str.encode(request_payload)
class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
"""Azure ML Online Endpoint chat models.