community[patch]: fix deepinfra inference (#22680)
This PR includes:
1. Updates the default model to Llama 3 (meta-llama/Meta-Llama-3-70B-Instruct).
2. Handles some 4xx errors with more user-friendly error messages.
3. Handles user errors reported in the response body.
commit f0f4532579
parent cb79e80b0b
@@ -13,7 +13,7 @@ from langchain_core.utils import get_from_dict_or_env
 
 from langchain_community.utilities.requests import Requests
 
-DEFAULT_MODEL_ID = "google/flan-t5-xl"
+DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
 
 
 class DeepInfra(LLM):
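Not part of the diff: a minimal sketch of what the new default means for callers, assuming a valid DEEPINFRA_API_TOKEN is set in the environment.

```python
# Minimal sketch, not part of the diff: constructing DeepInfra without an
# explicit model_id now targets Llama 3 instead of flan-t5-xl.
# Assumes DEEPINFRA_API_TOKEN is set in the environment.
from langchain_community.llms import DeepInfra

llm = DeepInfra()
print(llm.model_id)  # -> "meta-llama/Meta-Llama-3-70B-Instruct"
```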
@@ -85,7 +85,15 @@ class DeepInfra(LLM):
 
     def _handle_status(self, code: int, text: Any) -> None:
         if code >= 500:
-            raise Exception(f"DeepInfra Server: Error {code}")
+            raise Exception(f"DeepInfra Server: Error {text}")
+        elif code == 401:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 403:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 404:
+            raise Exception(f"DeepInfra Server: Model not found {self.model_id}")
+        elif code == 429:
+            raise Exception("DeepInfra Server: Rate limit exceeded")
         elif code >= 400:
             raise ValueError(f"DeepInfra received an invalid payload: {text}")
         elif code != 200:
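A hypothetical illustration of the new status mapping. `_handle_status` is a private helper, so calling it directly is done here only to show the messages it raises.

```python
# Hypothetical illustration only: _handle_status is a private helper, called
# directly here just to show the new, more specific error messages.
from langchain_community.llms import DeepInfra

llm = DeepInfra(deepinfra_api_token="fake-token")
for code in (401, 404, 429):
    try:
        llm._handle_status(code, "")
    except Exception as e:
        print(code, "->", e)
# 401 -> DeepInfra Server: Unauthorized
# 404 -> DeepInfra Server: Model not found meta-llama/Meta-Llama-3-70B-Instruct
# 429 -> DeepInfra Server: Rate limit exceeded
```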
@@ -150,7 +158,8 @@ class DeepInfra(LLM):
         response = request.post(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         )
-
+        response_text = response.text
+        self._handle_body_errors(response_text)
         self._handle_status(response.status_code, response.text)
         for line in _parse_stream(response.iter_lines()):
             chunk = _handle_sse_line(line)
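With this change, the synchronous streaming path surfaces body-level errors before the status check. A short usage sketch, assuming a valid DEEPINFRA_API_TOKEN:

```python
# Usage sketch, assuming a valid DEEPINFRA_API_TOKEN: body-level errors from
# the server are now raised by _handle_body_errors before the status check,
# instead of failing later while parsing the SSE stream.
from langchain_community.llms import DeepInfra

llm = DeepInfra()
for chunk in llm.stream("What is 2 + 2?"):
    print(chunk, end="")
```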
@@ -170,6 +179,8 @@ class DeepInfra(LLM):
         async with request.apost(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         ) as response:
+            response_text = await response.text()
+            self._handle_body_errors(response_text)
             self._handle_status(response.status, response.text)
             async for line in _parse_stream_async(response.content):
                 chunk = _handle_sse_line(line)
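The async path gets the same pre-check; a sketch under the same assumptions:

```python
# Async counterpart of the sketch above, under the same assumptions.
import asyncio

from langchain_community.llms import DeepInfra


async def main() -> None:
    llm = DeepInfra()
    async for chunk in llm.astream("What is 2 + 2?"):
        print(chunk, end="")


asyncio.run(main())
```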
@@ -178,6 +189,24 @@ class DeepInfra(LLM):
                     await run_manager.on_llm_new_token(chunk.text)
                 yield chunk
 
+    def _handle_body_errors(self, body: str) -> None:
+        """
+        Example error response:
+        data: {"error_type": "validation_error",
+        "error_message": "ConnectionError: ..."}
+        """
+        if "error" in body:
+            try:
+                # Remove data: prefix if present
+                if body.startswith("data:"):
+                    body = body[len("data:") :]
+                error_data = json.loads(body)
+                error_message = error_data.get("error_message", "Unknown error")
+
+                raise Exception(f"DeepInfra Server Error: {error_message}")
+            except json.JSONDecodeError:
+                raise Exception(f"DeepInfra Server: {body}")
+
 
 def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
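A hypothetical check of the new body-error parsing, fed the docstring's example payload:

```python
# Hypothetical check of the new parsing, using the docstring's example payload.
from langchain_community.llms import DeepInfra

llm = DeepInfra(deepinfra_api_token="fake-token")
body = 'data: {"error_type": "validation_error", "error_message": "ConnectionError: ..."}'
try:
    llm._handle_body_errors(body)
except Exception as e:
    print(e)  # -> DeepInfra Server Error: ConnectionError: ...
```

Note that the Exception raised inside the `try` block propagates, since the `except` clause only catches `json.JSONDecodeError` for bodies that are not valid JSON.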