From e64cf1aba473dd87121135ed511df4f52343b9b4 Mon Sep 17 00:00:00 2001 From: Rodrigo Nogueira <121117945+rodrigo-f-nogueira@users.noreply.github.com> Date: Sat, 16 Mar 2024 19:18:56 -0300 Subject: [PATCH] community: Add model argument for maritalk models and better error handling (#19187) --- docs/docs/integrations/chat/maritalk.ipynb | 1 + .../chat_models/maritalk.py | 63 ++++++++++++++----- 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/docs/docs/integrations/chat/maritalk.ipynb b/docs/docs/integrations/chat/maritalk.ipynb index bd2a04700f8..5ae77c16242 100644 --- a/docs/docs/integrations/chat/maritalk.ipynb +++ b/docs/docs/integrations/chat/maritalk.ipynb @@ -65,6 +65,7 @@ "from langchain_core.output_parsers import StrOutputParser\n", "\n", "llm = ChatMaritalk(\n", + " model=\"sabia-2-medium\", # Available models: sabia-2-small and sabia-2-medium\n", " api_key=\"\", # Insert your API key here\n", " temperature=0.7,\n", " max_tokens=100,\n", diff --git a/libs/community/langchain_community/chat_models/maritalk.py b/libs/community/langchain_community/chat_models/maritalk.py index ab90de19c88..064fd46fa17 100644 --- a/libs/community/langchain_community/chat_models/maritalk.py +++ b/libs/community/langchain_community/chat_models/maritalk.py @@ -1,3 +1,4 @@ +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union import requests @@ -5,6 +6,32 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import SimpleChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import Field +from requests import Response +from requests.exceptions import HTTPError + + +class MaritalkHTTPError(HTTPError): + def __init__(self, request_obj: Response) -> None: + self.request_obj = request_obj + try: + response_json = request_obj.json() + if "detail" in response_json: + api_message = response_json["detail"] + elif "message" in response_json: + api_message = response_json["message"] + else: + api_message = response_json + except Exception: + api_message = request_obj.text + + self.message = api_message + self.status_code = request_obj.status_code + + def __str__(self) -> str: + status_code_meaning = HTTPStatus(self.status_code).phrase + formatted_message = f"HTTP Error: {self.status_code} - {status_code_meaning}" + formatted_message += f"\nDetail: {self.message}" + return formatted_message class ChatMaritalk(SimpleChatModel): @@ -23,6 +50,14 @@ class ChatMaritalk(SimpleChatModel): api_key: str """Your MariTalk API key.""" + model: str + """Chose one of the available models: + - `sabia-2-medium` + - `sabia-2-small` + - `sabia-2-medium-2024-03-13` + - `sabia-2-small-2024-03-13` + - `maritalk-2024-01-08` (deprecated)""" + temperature: float = Field(default=0.7, gt=0.0, lt=1.0) """Run inference with this temperature. Must be in the closed interval [0.0, 1.0].""" @@ -37,10 +72,6 @@ class ChatMaritalk(SimpleChatModel): """Nucleus sampling parameter controlling the size of the probability mass considered for sampling.""" - system_message_workaround: bool = Field(default=True) - """Whether to include a workaround for system messages - by adding them as a user message.""" - @property def _llm_type(self) -> str: """Identifies the LLM type as 'maritalk'.""" @@ -64,17 +95,13 @@ class ChatMaritalk(SimpleChatModel): for message in messages: if isinstance(message, HumanMessage): - parsed_messages.append({"role": "user", "content": message.content}) + role = "user" elif isinstance(message, AIMessage): - parsed_messages.append( - {"role": "assistant", "content": message.content} - ) - elif isinstance(message, SystemMessage) and self.system_message_workaround: - # Maritalk models do not understand system message. - # #Instead we add these messages as user messages. - parsed_messages.append({"role": "user", "content": message.content}) - parsed_messages.append({"role": "assistant", "content": "ok"}) + role = "assistant" + elif isinstance(message, SystemMessage): + role = "system" + parsed_messages.append({"role": role, "content": message.content}) return parsed_messages def _call( @@ -114,6 +141,7 @@ class ChatMaritalk(SimpleChatModel): data = { "messages": parsed_messages, + "model": self.model, "do_sample": self.do_sample, "max_tokens": self.max_tokens, "temperature": self.temperature, @@ -123,10 +151,11 @@ class ChatMaritalk(SimpleChatModel): } response = requests.post(url, json=data, headers=headers) - if response.status_code == 429: - return "Rate limited, please try again soon" - elif response.ok: + + if response.ok: return response.json().get("answer", "No answer found") + else: + raise MaritalkHTTPError(response) except requests.exceptions.RequestException as e: return f"An error occurred: {str(e)}" @@ -144,7 +173,7 @@ class ChatMaritalk(SimpleChatModel): A dictionary of the key configuration parameters. """ return { - "system_message_workaround": self.system_message_workaround, + "model": self.model, "temperature": self.temperature, "top_p": self.top_p, "max_tokens": self.max_tokens,