mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
community: Add model argument for maritalk models and better error handling (#19187)
This commit is contained in:
parent
ff94f86ce1
commit
e64cf1aba4
@ -65,6 +65,7 @@
|
|||||||
"from langchain_core.output_parsers import StrOutputParser\n",
|
"from langchain_core.output_parsers import StrOutputParser\n",
|
||||||
"\n",
|
"\n",
|
||||||
"llm = ChatMaritalk(\n",
|
"llm = ChatMaritalk(\n",
|
||||||
|
" model=\"sabia-2-medium\", # Available models: sabia-2-small and sabia-2-medium\n",
|
||||||
" api_key=\"\", # Insert your API key here\n",
|
" api_key=\"\", # Insert your API key here\n",
|
||||||
" temperature=0.7,\n",
|
" temperature=0.7,\n",
|
||||||
" max_tokens=100,\n",
|
" max_tokens=100,\n",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -5,6 +6,32 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
|||||||
from langchain_core.language_models.chat_models import SimpleChatModel
|
from langchain_core.language_models.chat_models import SimpleChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
|
from requests import Response
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
|
||||||
|
class MaritalkHTTPError(HTTPError):
    """HTTP error raised by the MariTalk API, carrying the parsed error detail.

    Extracts a human-readable message from the failed ``Response`` so that the
    stringified exception shows both the status code and the API's explanation
    of what went wrong.
    """

    def __init__(self, request_obj: Response) -> None:
        """Build the error from a failed ``requests.Response``.

        Args:
            request_obj: The non-OK HTTP response returned by the API.
        """
        self.request_obj = request_obj
        try:
            response_json = request_obj.json()
            # The API reports errors under "detail" or "message" depending on
            # the endpoint; fall back to the whole JSON payload otherwise.
            if "detail" in response_json:
                api_message = response_json["detail"]
            elif "message" in response_json:
                api_message = response_json["message"]
            else:
                api_message = response_json
        except Exception:
            # Body was not valid JSON; use the raw response text instead.
            api_message = request_obj.text

        self.message = api_message
        self.status_code = request_obj.status_code
        # Initialize the HTTPError base so callers relying on the standard
        # ``requests`` exception contract (args, ``.response``) still work.
        super().__init__(str(api_message), response=request_obj)

    def __str__(self) -> str:
        """Return ``HTTP Error: <code> - <phrase>`` plus the API detail."""
        try:
            status_code_meaning = HTTPStatus(self.status_code).phrase
        except ValueError:
            # Non-standard status codes (e.g. 499) have no HTTPStatus member;
            # stringifying an exception must never raise.
            status_code_meaning = "Unknown Status Code"
        formatted_message = f"HTTP Error: {self.status_code} - {status_code_meaning}"
        formatted_message += f"\nDetail: {self.message}"
        return formatted_message
|
||||||
|
|
||||||
|
|
||||||
class ChatMaritalk(SimpleChatModel):
|
class ChatMaritalk(SimpleChatModel):
|
||||||
@ -23,6 +50,14 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
api_key: str
|
api_key: str
|
||||||
"""Your MariTalk API key."""
|
"""Your MariTalk API key."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
"""Chose one of the available models:
|
||||||
|
- `sabia-2-medium`
|
||||||
|
- `sabia-2-small`
|
||||||
|
- `sabia-2-medium-2024-03-13`
|
||||||
|
- `sabia-2-small-2024-03-13`
|
||||||
|
- `maritalk-2024-01-08` (deprecated)"""
|
||||||
|
|
||||||
temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
|
temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
|
||||||
"""Run inference with this temperature.
|
"""Run inference with this temperature.
|
||||||
Must be in the closed interval [0.0, 1.0]."""
|
Must be in the closed interval [0.0, 1.0]."""
|
||||||
@ -37,10 +72,6 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
"""Nucleus sampling parameter controlling the size of
|
"""Nucleus sampling parameter controlling the size of
|
||||||
the probability mass considered for sampling."""
|
the probability mass considered for sampling."""
|
||||||
|
|
||||||
system_message_workaround: bool = Field(default=True)
|
|
||||||
"""Whether to include a workaround for system messages
|
|
||||||
by adding them as a user message."""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Identifies the LLM type as 'maritalk'."""
|
"""Identifies the LLM type as 'maritalk'."""
|
||||||
@ -64,17 +95,13 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if isinstance(message, HumanMessage):
|
if isinstance(message, HumanMessage):
|
||||||
parsed_messages.append({"role": "user", "content": message.content})
|
role = "user"
|
||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
parsed_messages.append(
|
role = "assistant"
|
||||||
{"role": "assistant", "content": message.content}
|
elif isinstance(message, SystemMessage):
|
||||||
)
|
role = "system"
|
||||||
elif isinstance(message, SystemMessage) and self.system_message_workaround:
|
|
||||||
# Maritalk models do not understand system message.
|
|
||||||
# #Instead we add these messages as user messages.
|
|
||||||
parsed_messages.append({"role": "user", "content": message.content})
|
|
||||||
parsed_messages.append({"role": "assistant", "content": "ok"})
|
|
||||||
|
|
||||||
|
parsed_messages.append({"role": role, "content": message.content})
|
||||||
return parsed_messages
|
return parsed_messages
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
@ -114,6 +141,7 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
|
|
||||||
data = {
|
data = {
|
||||||
"messages": parsed_messages,
|
"messages": parsed_messages,
|
||||||
|
"model": self.model,
|
||||||
"do_sample": self.do_sample,
|
"do_sample": self.do_sample,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
@ -123,10 +151,11 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(url, json=data, headers=headers)
|
response = requests.post(url, json=data, headers=headers)
|
||||||
if response.status_code == 429:
|
|
||||||
return "Rate limited, please try again soon"
|
if response.ok:
|
||||||
elif response.ok:
|
|
||||||
return response.json().get("answer", "No answer found")
|
return response.json().get("answer", "No answer found")
|
||||||
|
else:
|
||||||
|
raise MaritalkHTTPError(response)
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
return f"An error occurred: {str(e)}"
|
return f"An error occurred: {str(e)}"
|
||||||
@ -144,7 +173,7 @@ class ChatMaritalk(SimpleChatModel):
|
|||||||
A dictionary of the key configuration parameters.
|
A dictionary of the key configuration parameters.
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"system_message_workaround": self.system_message_workaround,
|
"model": self.model,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
|
Loading…
Reference in New Issue
Block a user