community[patch]: fix deepinfra inference (#22680)

This PR includes:

1. Updates the default model to Llama 3 (`meta-llama/Meta-Llama-3-70B-Instruct`).
2. Handles some 4xx errors with more user-friendly error messages.
3. Handles user errors returned in the response body during streaming (see the sketch below).
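As a quick illustration of item 2, a minimal sketch of the friendlier errors, assuming an invalid API token (the printed message mirrors the new `_handle_status` mapping; the output comment is illustrative):

```python
from langchain_community.llms import DeepInfra

# hypothetical invalid token, used only to trigger the 401 path
llm = DeepInfra(deepinfra_api_token="invalid-token")
try:
    llm.invoke("Hello")
except Exception as err:
    print(err)  # e.g. "DeepInfra Server: Unauthorized"
```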
Oguz Vuruskaner 2024-06-10 23:55:55 +03:00 committed by GitHub
parent cb79e80b0b
commit f0f4532579

@@ -13,7 +13,7 @@ from langchain_core.utils import get_from_dict_or_env
 from langchain_community.utilities.requests import Requests
-DEFAULT_MODEL_ID = "google/flan-t5-xl"
+DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
 class DeepInfra(LLM):
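With the new default, constructing `DeepInfra` without an explicit `model_id` targets Llama 3. A minimal sketch, assuming `DEEPINFRA_API_TOKEN` is set in the environment:

```python
from langchain_community.llms import DeepInfra

llm = DeepInfra()  # previously defaulted to "google/flan-t5-xl"
assert llm.model_id == "meta-llama/Meta-Llama-3-70B-Instruct"
```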
@@ -85,7 +85,15 @@ class DeepInfra(LLM):
     def _handle_status(self, code: int, text: Any) -> None:
         if code >= 500:
-            raise Exception(f"DeepInfra Server: Error {code}")
+            raise Exception(f"DeepInfra Server: Error {text}")
+        elif code == 401:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 403:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 404:
+            raise Exception(f"DeepInfra Server: Model not found {self.model_id}")
+        elif code == 429:
+            raise Exception("DeepInfra Server: Rate limit exceeded")
         elif code >= 400:
             raise ValueError(f"DeepInfra received an invalid payload: {text}")
         elif code != 200:
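A sketch exercising the new mapping directly with pytest; `_handle_status` is a private method and the token below is a placeholder, so this is for illustration only:

```python
import pytest
from langchain_community.llms import DeepInfra

llm = DeepInfra(deepinfra_api_token="placeholder-token")

# 429 now maps to a dedicated rate-limit message
with pytest.raises(Exception, match="Rate limit exceeded"):
    llm._handle_status(429, "")

# other 4xx responses still raise ValueError with the payload attached
with pytest.raises(ValueError, match="invalid payload"):
    llm._handle_status(400, "bad request body")
```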
@@ -150,7 +158,8 @@ class DeepInfra(LLM):
         response = request.post(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         )
+        response_text = response.text
+        self._handle_body_errors(response_text)
         self._handle_status(response.status_code, response.text)
         for line in _parse_stream(response.iter_lines()):
             chunk = _handle_sse_line(line)
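In the streaming path the body check now runs before status handling, so server-side validation errors surface with their message instead of a parse failure. A usage sketch, assuming `DEEPINFRA_API_TOKEN` is set:

```python
from langchain_community.llms import DeepInfra

llm = DeepInfra()
try:
    for chunk in llm.stream("Write a haiku about rivers."):
        print(chunk, end="", flush=True)
except Exception as err:
    # _handle_body_errors raises first if the body carries an error payload
    print(f"DeepInfra error: {err}")
```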
@@ -170,6 +179,8 @@ class DeepInfra(LLM):
         async with request.apost(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         ) as response:
+            response_text = await response.text()
+            self._handle_body_errors(response_text)
             self._handle_status(response.status, response.text)
             async for line in _parse_stream_async(response.content):
                 chunk = _handle_sse_line(line)
@@ -178,6 +189,24 @@ class DeepInfra(LLM):
                     await run_manager.on_llm_new_token(chunk.text)
                 yield chunk

+    def _handle_body_errors(self, body: str) -> None:
+        """
+        Example error response:
+        data: {"error_type": "validation_error",
+        "error_message": "ConnectionError: ..."}
+        """
+        if "error" in body:
+            try:
+                # Remove data: prefix if present
+                if body.startswith("data:"):
+                    body = body[len("data:") :]
+                error_data = json.loads(body)
+                error_message = error_data.get("error_message", "Unknown error")
+                raise Exception(f"DeepInfra Server Error: {error_message}")
+            except json.JSONDecodeError:
+                raise Exception(f"DeepInfra Server: {body}")

 def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
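To see the new body check in isolation, a sketch feeding `_handle_body_errors` a fabricated SSE error line (the payload is illustrative, and the private-method call is for demonstration only):

```python
from langchain_community.llms import DeepInfra

llm = DeepInfra(deepinfra_api_token="placeholder-token")
line = (
    'data: {"error_type": "validation_error", '
    '"error_message": "ConnectionError: upstream timed out"}'
)
try:
    llm._handle_body_errors(line)
except Exception as err:
    print(err)  # DeepInfra Server Error: ConnectionError: upstream timed out
```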