diff --git a/libs/community/langchain_community/llms/deepinfra.py b/libs/community/langchain_community/llms/deepinfra.py index 65911921a2f..6d8deb0caa8 100644 --- a/libs/community/langchain_community/llms/deepinfra.py +++ b/libs/community/langchain_community/llms/deepinfra.py @@ -13,7 +13,7 @@ from langchain_core.utils import get_from_dict_or_env from langchain_community.utilities.requests import Requests -DEFAULT_MODEL_ID = "google/flan-t5-xl" +DEFAULT_MODEL_ID = "meta-llama/Meta-Llama-3-70B-Instruct" class DeepInfra(LLM): @@ -85,7 +85,15 @@ class DeepInfra(LLM): def _handle_status(self, code: int, text: Any) -> None: if code >= 500: - raise Exception(f"DeepInfra Server: Error {code}") + raise Exception(f"DeepInfra Server: Error {text}") + elif code == 401: + raise Exception("DeepInfra Server: Unauthorized") + elif code == 403: + raise Exception("DeepInfra Server: Unauthorized") + elif code == 404: + raise Exception(f"DeepInfra Server: Model not found {self.model_id}") + elif code == 429: + raise Exception("DeepInfra Server: Rate limit exceeded") elif code >= 400: raise ValueError(f"DeepInfra received an invalid payload: {text}") elif code != 200: @@ -150,7 +158,8 @@ class DeepInfra(LLM): response = request.post( url=self._url(), data=self._body(prompt, {**kwargs, "stream": True}) ) - + response_text = response.text + self._handle_body_errors(response_text) self._handle_status(response.status_code, response.text) for line in _parse_stream(response.iter_lines()): chunk = _handle_sse_line(line) @@ -170,6 +179,8 @@ class DeepInfra(LLM): async with request.apost( url=self._url(), data=self._body(prompt, {**kwargs, "stream": True}) ) as response: + response_text = await response.text() + self._handle_body_errors(response_text) self._handle_status(response.status, response.text) async for line in _parse_stream_async(response.content): chunk = _handle_sse_line(line) @@ -178,6 +189,24 @@ class DeepInfra(LLM): await run_manager.on_llm_new_token(chunk.text) yield chunk + 
def _handle_body_errors(self, body: str) -> None:
    """Raise if ``body`` is an error payload rather than generated output.

    An error payload is a JSON object, optionally carrying an SSE ``data:``
    prefix, that has a top-level ``error``, ``error_type`` or
    ``error_message`` key, for example::

        data: {"error_type": "validation_error",
               "error_message": "ConnectionError: ..."}

    Any other body is accepted silently. In particular, a body that merely
    contains the word "error" (for instance inside generated text), or that
    is not a single JSON document (for instance several SSE events), is not
    treated as a failure.

    Args:
        body: Response text to inspect.

    Raises:
        Exception: ``DeepInfra Server Error: <error_message>`` when an error
            payload is found; ``Unknown error`` is used when the payload has
            no ``error_message`` key.
    """
    # Cheap pre-filter: skip JSON parsing when no error key can be present.
    if not body or "error" not in body:
        return

    payload = body.strip()
    # Try the whole body first (covers a JSON object spread over several
    # lines), then only its first line (covers an error event that is
    # followed by further lines).
    for candidate in (payload, payload.split("\n", 1)[0]):
        candidate = candidate.strip()
        if candidate.startswith("data:"):
            candidate = candidate[len("data:") :]
        try:
            error_data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        break
    else:
        # Nothing decodable: the substring match alone is not evidence of
        # a server-side failure.
        return

    # Only a JSON object can carry error keys; lists/strings/numbers cannot.
    if not isinstance(error_data, dict):
        return
    # Require a top-level error key so nested text such as
    # {"token": {"text": "error"}} is not mistaken for a failure.
    if not any(
        key in error_data for key in ("error", "error_type", "error_message")
    ):
        return

    error_message = error_data.get("error_message", "Unknown error")
    raise Exception(f"DeepInfra Server Error: {error_message}")