Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-26 16:43:35 +00:00
community[patch]: fix deepinfra inference (#22680)
This PR includes:
1. Update of the default model to Llama 3.
2. Handling of some 4xx errors with more user-friendly error messages.
3. Handling of user errors.
Parent: cb79e80b0b
Commit: f0f4532579
@@ -13,7 +13,7 @@ from langchain_core.utils import get_from_dict_or_env
 
 from langchain_community.utilities.requests import Requests
 
-DEFAULT_MODEL_ID = "google/flan-t5-xl"
+DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct"
 
 
 class DeepInfra(LLM):
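For context, a minimal usage sketch against the new default model (the prompt and the placeholder token are illustrative, not part of this change):

from langchain_community.llms import DeepInfra

# model_id can now be omitted to get Meta-Llama-3-70B-Instruct by default;
# the token may also come from the DEEPINFRA_API_TOKEN environment variable.
llm = DeepInfra(deepinfra_api_token="<your-token>")
print(llm.invoke("Summarize the rules of chess in one sentence."))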
@@ -85,7 +85,15 @@ class DeepInfra(LLM):
 
     def _handle_status(self, code: int, text: Any) -> None:
         if code >= 500:
-            raise Exception(f"DeepInfra Server: Error {code}")
+            raise Exception(f"DeepInfra Server: Error {text}")
+        elif code == 401:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 403:
+            raise Exception("DeepInfra Server: Unauthorized")
+        elif code == 404:
+            raise Exception(f"DeepInfra Server: Model not found {self.model_id}")
+        elif code == 429:
+            raise Exception("DeepInfra Server: Rate limit exceeded")
         elif code >= 400:
             raise ValueError(f"DeepInfra received an invalid payload: {text}")
         elif code != 200:
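A sketch of how callers might surface the new status-code messages (prompt and handling are illustrative; assumes DEEPINFRA_API_TOKEN is set in the environment):

from langchain_community.llms import DeepInfra

llm = DeepInfra(model_id="meta-llama/Meta-Llama-3-70B-Instruct")

try:
    llm.invoke("Hello")
except ValueError as err:
    # Remaining 4xx payload problems still raise ValueError, as before.
    print(f"Invalid payload: {err}")
except Exception as err:
    # 401/403 -> "Unauthorized", 404 -> "Model not found <model_id>",
    # 429 -> "Rate limit exceeded", 5xx -> the server's error text.
    print(f"DeepInfra error: {err}")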
@@ -150,7 +158,8 @@ class DeepInfra(LLM):
         response = request.post(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         )
+        response_text = response.text
+        self._handle_body_errors(response_text)
         self._handle_status(response.status_code, response.text)
         for line in _parse_stream(response.iter_lines()):
             chunk = _handle_sse_line(line)
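A streaming sketch for the synchronous path; with this change an error body raises before any SSE lines are parsed instead of yielding an empty stream (prompt illustrative, DEEPINFRA_API_TOKEN assumed set):

from langchain_community.llms import DeepInfra

llm = DeepInfra(model_id="meta-llama/Meta-Llama-3-70B-Instruct")

try:
    for chunk in llm.stream("Write a haiku about rivers."):
        print(chunk, end="", flush=True)
except Exception as err:
    # Validation errors returned in the response body now surface here immediately.
    print(f"\nStreaming failed: {err}")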
@@ -170,6 +179,8 @@ class DeepInfra(LLM):
         async with request.apost(
             url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
         ) as response:
+            response_text = await response.text()
+            self._handle_body_errors(response_text)
             self._handle_status(response.status, response.text)
             async for line in _parse_stream_async(response.content):
                 chunk = _handle_sse_line(line)
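The async counterpart, again illustrative and assuming DEEPINFRA_API_TOKEN is set:

import asyncio

from langchain_community.llms import DeepInfra


async def main() -> None:
    llm = DeepInfra(model_id="meta-llama/Meta-Llama-3-70B-Instruct")
    try:
        async for chunk in llm.astream("Write a haiku about rivers."):
            print(chunk, end="", flush=True)
    except Exception as err:
        # Same body-level check as the sync path, performed before SSE parsing.
        print(f"\nStreaming failed: {err}")


asyncio.run(main())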
@@ -178,6 +189,24 @@ class DeepInfra(LLM):
                     await run_manager.on_llm_new_token(chunk.text)
                 yield chunk
 
+    def _handle_body_errors(self, body: str) -> None:
+        """
+        Example error response:
+        data: {"error_type": "validation_error",
+        "error_message": "ConnectionError: ..."}
+        """
+        if "error" in body:
+            try:
+                # Remove data: prefix if present
+                if body.startswith("data:"):
+                    body = body[len("data:") :]
+                error_data = json.loads(body)
+                error_message = error_data.get("error_message", "Unknown error")
+
+                raise Exception(f"DeepInfra Server Error: {error_message}")
+            except json.JSONDecodeError:
+                raise Exception(f"DeepInfra Server: {body}")
+
 
 def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
     for line in rbody:
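To see the new helper in isolation, a sketch only: _handle_body_errors is a private method, and the payload below merely mirrors the docstring example rather than a real API response:

from langchain_community.llms import DeepInfra

llm = DeepInfra(deepinfra_api_token="token-for-illustration-only")

error_body = (
    'data: {"error_type": "validation_error", '
    '"error_message": "ConnectionError: upstream model unavailable"}'
)

try:
    llm._handle_body_errors(error_body)
except Exception as err:
    # Prints: DeepInfra Server Error: ConnectionError: upstream model unavailable
    print(err)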