Improvements to llm/deepinfra (#10846)

- replace `requests` package with `langchain.utilities.requests`
- add `_acall` support
- add `_stream` and `_astream`
- freshen up the documentation a bit
- update vendor doc
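
A rough sketch of the resulting surface area (the import path and model id follow the integration tests below; assumes `DEEPINFRA_API_TOKEN` is set in the environment):

    import asyncio

    from langchain.llms.deepinfra import DeepInfra

    llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")

    # sync call (public API unchanged)
    print(llm("What is 2 + 2?"))

    # streaming now yields tokens incrementally, backed by the new _stream
    for chunk in llm.stream("[INST] Hello [/INST] "):
        print(chunk, end="", flush=True)

    # async call, backed by the new _acall
    async def main() -> None:
        print(await llm.apredict("What is 2 + 2?"))

    asyncio.run(main())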
Author: Iskren Ivov Chernev
Date: 2023-10-24 19:54:23 +03:00
Committed by: GitHub
Parent: f09f82541b
Commit: d5d7ba582a
4 changed files with 252 additions and 78 deletions

langchain/llms/deepinfra.py

@@ -1,11 +1,15 @@
-from typing import Any, Dict, List, Mapping, Optional
+import json
+from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
 
-import requests
+import aiohttp
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-from langchain.llms.utils import enforce_stop_tokens
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM, GenerationChunk
 from langchain.pydantic_v1 import Extra, root_validator
+from langchain.utilities.requests import Requests
 from langchain.utils import get_from_dict_or_env
 
 DEFAULT_MODEL_ID = "google/flan-t5-xl"
@@ -14,9 +18,9 @@ DEFAULT_MODEL_ID = "google/flan-t5-xl"
 class DeepInfra(LLM):
     """DeepInfra models.
 
-    To use, you should have the ``requests`` python package installed, and the
-    environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass
-    it as a named parameter to the constructor.
+    To use, you should have the environment variable ``DEEPINFRA_API_TOKEN``
+    set with your API token, or pass it as a named parameter to the
+    constructor.
 
     Only supports `text-generation` and `text2text-generation` for now.
@@ -29,7 +33,7 @@ class DeepInfra(LLM):
     """
 
     model_id: str = DEFAULT_MODEL_ID
-    model_kwargs: Optional[dict] = None
+    model_kwargs: Optional[Dict] = None
     deepinfra_api_token: Optional[str] = None
 
     class Config:
@@ -60,6 +64,35 @@ class DeepInfra(LLM):
         """Return type of llm."""
         return "deepinfra"
 
+    def _url(self) -> str:
+        return f"https://api.deepinfra.com/v1/inference/{self.model_id}"
+
+    def _headers(self) -> Dict:
+        return {
+            "Authorization": f"bearer {self.deepinfra_api_token}",
+            "Content-Type": "application/json",
+        }
+
+    def _body(self, prompt: str, kwargs: Any) -> Dict:
+        model_kwargs = self.model_kwargs or {}
+        model_kwargs = {**model_kwargs, **kwargs}
+        return {
+            "input": prompt,
+            **model_kwargs,
+        }
+
+    def _handle_status(self, code: int, text: Any) -> None:
+        if code >= 500:
+            raise Exception(f"DeepInfra Server: Error {code}")
+        elif code >= 400:
+            raise ValueError(f"DeepInfra received an invalid payload: {text}")
+        elif code != 200:
+            raise Exception(
+                f"DeepInfra returned an unexpected response with status "
+                f"{code}: {text}"
+            )
+
     def _call(
         self,
         prompt: str,
@@ -81,38 +114,105 @@ class DeepInfra(LLM):
                 response = di("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-        _model_kwargs = {**_model_kwargs, **kwargs}
-
-        # HTTP headers for authorization
-        headers = {
-            "Authorization": f"bearer {self.deepinfra_api_token}",
-            "Content-Type": "application/json",
-        }
-
-        try:
-            res = requests.post(
-                f"https://api.deepinfra.com/v1/inference/{self.model_id}",
-                headers=headers,
-                json={"input": prompt, **_model_kwargs},
-            )
-        except requests.exceptions.RequestException as e:
-            raise ValueError(f"Error raised by inference endpoint: {e}")
-
+        request = Requests(headers=self._headers())
+        response = request.post(url=self._url(), data=self._body(prompt, kwargs))
-        if res.status_code != 200:
-            raise ValueError(
-                "Error raised by inference API HTTP code: %s, %s"
-                % (res.status_code, res.text)
-            )
-        try:
-            t = res.json()
-            text = t["results"][0]["generated_text"]
-        except requests.exceptions.JSONDecodeError as e:
-            raise ValueError(
-                f"Error raised by inference API: {e}.\nResponse: {res.text}"
-            )
+        self._handle_status(response.status_code, response.text)
+        data = response.json()
-
-        if stop is not None:
-            # I believe this is required since the stop tokens
-            # are not enforced by the model parameters
-            text = enforce_stop_tokens(text, stop)
-        return text
+        return data["results"][0]["generated_text"]
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        request = Requests(headers=self._headers())
+        async with request.apost(
+            url=self._url(), data=self._body(prompt, kwargs)
+        ) as response:
+            # read the body only on error; aiohttp's `text()` is a coroutine
+            if response.status != 200:
+                self._handle_status(response.status, await response.text())
+            data = await response.json()
+            return data["results"][0]["generated_text"]
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        request = Requests(headers=self._headers())
+        response = request.post(
+            url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
+        )
+        self._handle_status(response.status_code, response.text)
+        for line in _parse_stream(response.iter_lines()):
+            chunk = _handle_sse_line(line)
+            if chunk:
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(chunk.text)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        request = Requests(headers=self._headers())
+        async with request.apost(
+            url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
+        ) as response:
+            # check before consuming the stream; `text()` would read the whole body
+            if response.status != 200:
+                self._handle_status(response.status, await response.text())
+            async for line in _parse_stream_async(response.content):
+                chunk = _handle_sse_line(line)
+                if chunk:
+                    yield chunk
+                    if run_manager:
+                        await run_manager.on_llm_new_token(chunk.text)
+
+
+def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
+    for line in rbody:
+        _line = _parse_stream_helper(line)
+        if _line is not None:
+            yield _line
+
+
+async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
+    async for line in rbody:
+        _line = _parse_stream_helper(line)
+        if _line is not None:
+            yield _line
+
+
+def _parse_stream_helper(line: bytes) -> Optional[str]:
+    if line and line.startswith(b"data:"):
+        if line.startswith(b"data: "):
+            # an SSE event may be valid even when it contains whitespace
+            line = line[len(b"data: ") :]
+        else:
+            line = line[len(b"data:") :]
+        if line.strip() == b"[DONE]":
+            # returning here would cause a GeneratorExit exception in urllib3,
+            # which closes the HTTP connection with a TCP reset
+            return None
+        else:
+            return line.decode("utf-8")
+    return None
+
+
+def _handle_sse_line(line: str) -> Optional[GenerationChunk]:
+    try:
+        obj = json.loads(line)
+        return GenerationChunk(
+            text=obj.get("token", {}).get("text"),
+        )
+    except Exception:
+        return None
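
For reference, the server-sent-events framing these helpers consume looks roughly like the sketch below. The sample payloads are illustrative (the "token" -> "text" shape is what `_handle_sse_line` expects), not captured API output:

    import json
    from typing import Optional

    # illustrative SSE lines; real DeepInfra output may differ in detail
    sample = [
        b'data: {"token": {"text": "Hello"}}',
        b'data: {"token": {"text": " world"}}',
        b"data: [DONE]",  # terminator, mapped to None by _parse_stream_helper
    ]

    def parse(line: bytes) -> Optional[str]:
        # same logic as _parse_stream_helper above
        if line and line.startswith(b"data:"):
            prefix = b"data: " if line.startswith(b"data: ") else b"data:"
            payload = line[len(prefix):]
            if payload.strip() == b"[DONE]":
                return None
            return payload.decode("utf-8")
        return None

    for raw in sample:
        text = parse(raw)
        if text is not None:
            print(json.loads(text)["token"]["text"], end="")  # -> Hello world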

tests/integration_tests/llms/test_deepinfra.py

@@ -1,10 +1,36 @@
 """Test DeepInfra API wrapper."""
+import pytest
 
 from langchain.llms.deepinfra import DeepInfra
 
 
 def test_deepinfra_call() -> None:
     """Test valid call to DeepInfra."""
-    llm = DeepInfra(model_id="google/flan-t5-small")
+    llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
     output = llm("What is 2 + 2?")
     assert isinstance(output, str)
+
+
+@pytest.mark.asyncio
+async def test_deepinfra_acall() -> None:
+    llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
+    output = await llm.apredict("What is 2 + 2?")
+    assert llm._llm_type == "deepinfra"
+    assert isinstance(output, str)
+
+
+def test_deepinfra_stream() -> None:
+    llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
+    num_chunks = 0
+    for chunk in llm.stream("[INST] Hello [/INST] "):
+        num_chunks += 1
+    assert num_chunks > 0
+
+
+@pytest.mark.asyncio
+async def test_deepinfra_astream() -> None:
+    llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
+    num_chunks = 0
+    async for chunk in llm.astream("[INST] Hello [/INST] "):
+        num_chunks += 1
+    assert num_chunks > 0
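
These are live integration tests, so they need a real `DEEPINFRA_API_TOKEN` in the environment. A hypothetical guard (not part of this commit) one could add at module level to skip them locally when the token is missing:

    import os

    import pytest

    # hypothetical module-level marker; not in the actual test file
    pytestmark = pytest.mark.skipif(
        "DEEPINFRA_API_TOKEN" not in os.environ,
        reason="requires a DeepInfra API token",
    )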