mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 02:33:19 +00:00
Fix ray-project/Aviary integration (#6607)
- Description: The aviary integration has changed url link. This PR provide fix for those changes and also it makes providing the input URL optional to the API (since they can be set via env variables). - Issue: N/A - Dependencies: N/A - Twitter handle: N/A --------- Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
parent
dbe1d029ec
commit
f6fdabd20b
@ -1,8 +1,10 @@
|
|||||||
"""Wrapper around Aviary"""
|
"""Wrapper around Aviary"""
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
import dataclasses
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional, Union, cast
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic import Extra, Field, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
@ -12,6 +14,68 @@ from langchain.utils import get_from_dict_or_env
|
|||||||
TIMEOUT = 60
|
TIMEOUT = 60
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AviaryBackend:
|
||||||
|
backend_url: str
|
||||||
|
bearer: str
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
self.header = {"Authorization": self.bearer}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> "AviaryBackend":
|
||||||
|
aviary_url = os.getenv("AVIARY_URL")
|
||||||
|
assert aviary_url, "AVIARY_URL must be set"
|
||||||
|
|
||||||
|
aviary_token = os.getenv("AVIARY_TOKEN", "")
|
||||||
|
|
||||||
|
bearer = f"Bearer {aviary_token}" if aviary_token else ""
|
||||||
|
aviary_url += "/" if not aviary_url.endswith("/") else ""
|
||||||
|
|
||||||
|
return cls(aviary_url, bearer)
|
||||||
|
|
||||||
|
|
||||||
|
def get_models() -> List[str]:
|
||||||
|
"""List available models"""
|
||||||
|
backend = AviaryBackend.from_env()
|
||||||
|
request_url = backend.backend_url + "-/routes"
|
||||||
|
response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
|
||||||
|
try:
|
||||||
|
result = response.json()
|
||||||
|
except requests.JSONDecodeError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Error decoding JSON from {request_url}. Text response: {response.text}"
|
||||||
|
) from e
|
||||||
|
result = sorted(
|
||||||
|
[k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_completions(
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
use_prompt_format: bool = True,
|
||||||
|
version: str = "",
|
||||||
|
) -> Dict[str, Union[str, float, int]]:
|
||||||
|
"""Get completions from Aviary models."""
|
||||||
|
|
||||||
|
backend = AviaryBackend.from_env()
|
||||||
|
url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=backend.header,
|
||||||
|
json={"prompt": prompt, "use_prompt_format": use_prompt_format},
|
||||||
|
timeout=TIMEOUT,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return response.json()
|
||||||
|
except requests.JSONDecodeError as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Error decoding JSON from {url}. Text response: {response.text}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class Aviary(LLM):
|
class Aviary(LLM):
|
||||||
"""Allow you to use an Aviary.
|
"""Allow you to use an Aviary.
|
||||||
|
|
||||||
@ -19,33 +83,30 @@ class Aviary(LLM):
|
|||||||
find out more about aviary at
|
find out more about aviary at
|
||||||
http://github.com/ray-project/aviary
|
http://github.com/ray-project/aviary
|
||||||
|
|
||||||
Has no dependencies, since it connects to backend
|
|
||||||
directly.
|
|
||||||
|
|
||||||
To get a list of the models supported on an
|
To get a list of the models supported on an
|
||||||
aviary, follow the instructions on the web site to
|
aviary, follow the instructions on the web site to
|
||||||
install the aviary CLI and then use:
|
install the aviary CLI and then use:
|
||||||
`aviary models`
|
`aviary models`
|
||||||
|
|
||||||
You must at least specify the environment
|
AVIARY_URL and AVIARY_TOKEN environement variables must be set.
|
||||||
variable or parameter AVIARY_URL.
|
|
||||||
|
|
||||||
You may optionally specify the environment variable
|
|
||||||
or parameter AVIARY_TOKEN.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.llms import Aviary
|
from langchain.llms import Aviary
|
||||||
light = Aviary(aviary_url='AVIARY_URL',
|
os.environ["AVIARY_URL"] = "<URL>"
|
||||||
model='amazon/LightGPT')
|
os.environ["AVIARY_TOKEN"] = "<TOKEN>"
|
||||||
|
light = Aviary(model='amazon/LightGPT')
|
||||||
result = light.predict('How do you make fried rice?')
|
output = light('How do you make fried rice?')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str
|
model: str = "amazon/LightGPT"
|
||||||
aviary_url: str
|
aviary_url: Optional[str] = None
|
||||||
aviary_token: str = Field("", exclude=True)
|
aviary_token: Optional[str] = None
|
||||||
|
# If True the prompt template for the model will be ignored.
|
||||||
|
use_prompt_format: bool = True
|
||||||
|
# API version to use for Aviary
|
||||||
|
version: Optional[str] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -56,49 +117,35 @@ class Aviary(LLM):
|
|||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
|
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
|
||||||
if not aviary_url.endswith("/"):
|
aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
|
||||||
aviary_url += "/"
|
|
||||||
values["aviary_url"] = aviary_url
|
# Set env viarables for aviary sdk
|
||||||
aviary_token = get_from_dict_or_env(
|
os.environ["AVIARY_URL"] = aviary_url
|
||||||
values, "aviary_token", "AVIARY_TOKEN", default=""
|
os.environ["AVIARY_TOKEN"] = aviary_token
|
||||||
)
|
|
||||||
values["aviary_token"] = aviary_token
|
|
||||||
|
|
||||||
aviary_endpoint = aviary_url + "models"
|
|
||||||
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(aviary_endpoint, headers=headers)
|
aviary_models = get_models()
|
||||||
result = response.json()
|
|
||||||
# Confirm model is available
|
|
||||||
if values["model"] not in result:
|
|
||||||
raise ValueError(
|
|
||||||
f"{aviary_url} does not support model {values['model']}."
|
|
||||||
)
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
|
|
||||||
|
model = values.get("model")
|
||||||
|
if model and model not in aviary_models:
|
||||||
|
raise ValueError(f"{aviary_url} does not support model {values['model']}.")
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {
|
return {
|
||||||
|
"model_name": self.model,
|
||||||
"aviary_url": self.aviary_url,
|
"aviary_url": self.aviary_url,
|
||||||
"aviary_token": self.aviary_token,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "aviary"
|
return f"aviary-{self.model.replace('/', '-')}"
|
||||||
|
|
||||||
@property
|
|
||||||
def headers(self) -> Dict[str, str]:
|
|
||||||
if self.aviary_token:
|
|
||||||
return {"Authorization": f"Bearer {self.aviary_token}"}
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -119,19 +166,18 @@ class Aviary(LLM):
|
|||||||
|
|
||||||
response = aviary("Tell me a joke.")
|
response = aviary("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
url = self.aviary_url + "query/" + self.model.replace("/", "--")
|
kwargs = {"use_prompt_format": self.use_prompt_format}
|
||||||
response = requests.post(
|
if self.version:
|
||||||
url,
|
kwargs["version"] = self.version
|
||||||
headers=self.headers,
|
|
||||||
json={"prompt": prompt},
|
output = get_completions(
|
||||||
timeout=TIMEOUT,
|
model=self.model,
|
||||||
|
prompt=prompt,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
text = response.json()[self.model]["generated_text"]
|
text = cast(str, output["generated_text"])
|
||||||
except requests.JSONDecodeError as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Error decoding JSON from {url}. Text response: {response.text}",
|
|
||||||
) from e
|
|
||||||
if stop:
|
if stop:
|
||||||
text = enforce_stop_tokens(text, stop)
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
@ -5,6 +5,7 @@ from langchain.llms.aviary import Aviary
|
|||||||
|
|
||||||
def test_aviary_call() -> None:
|
def test_aviary_call() -> None:
|
||||||
"""Test valid call to Anyscale."""
|
"""Test valid call to Anyscale."""
|
||||||
llm = Aviary(model="test/model")
|
llm = Aviary()
|
||||||
output = llm("Say bar:")
|
output = llm("Say bar:")
|
||||||
|
print(f"llm answer:\n{output}")
|
||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
|
Loading…
Reference in New Issue
Block a user