Fix ray-project/Aviary integration (#6607)

- Description: The aviary integration has changed url link. This PR
provide fix for those changes and also it makes providing the input URL
optional to the API (since they can be set via env variables).
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
kourosh hakhamaneshi 2023-06-23 14:49:53 -07:00 committed by GitHub
parent dbe1d029ec
commit f6fdabd20b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 56 deletions

View File

@ -1,8 +1,10 @@
"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional
import dataclasses
import os
from typing import Any, Dict, List, Mapping, Optional, Union, cast
import requests
from pydantic import Extra, Field, root_validator
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
@ -12,6 +14,68 @@ from langchain.utils import get_from_dict_or_env
TIMEOUT = 60
@dataclasses.dataclass
class AviaryBackend:
backend_url: str
bearer: str
def __post_init__(self) -> None:
self.header = {"Authorization": self.bearer}
@classmethod
def from_env(cls) -> "AviaryBackend":
aviary_url = os.getenv("AVIARY_URL")
assert aviary_url, "AVIARY_URL must be set"
aviary_token = os.getenv("AVIARY_TOKEN", "")
bearer = f"Bearer {aviary_token}" if aviary_token else ""
aviary_url += "/" if not aviary_url.endswith("/") else ""
return cls(aviary_url, bearer)
def get_models() -> List[str]:
"""List available models"""
backend = AviaryBackend.from_env()
request_url = backend.backend_url + "-/routes"
response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
try:
result = response.json()
except requests.JSONDecodeError as e:
raise RuntimeError(
f"Error decoding JSON from {request_url}. Text response: {response.text}"
) from e
result = sorted(
[k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
)
return result
def get_completions(
model: str,
prompt: str,
use_prompt_format: bool = True,
version: str = "",
) -> Dict[str, Union[str, float, int]]:
"""Get completions from Aviary models."""
backend = AviaryBackend.from_env()
url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
response = requests.post(
url,
headers=backend.header,
json={"prompt": prompt, "use_prompt_format": use_prompt_format},
timeout=TIMEOUT,
)
try:
return response.json()
except requests.JSONDecodeError as e:
raise RuntimeError(
f"Error decoding JSON from {url}. Text response: {response.text}"
) from e
class Aviary(LLM):
"""Allow you to use an Aviary.
@ -19,33 +83,30 @@ class Aviary(LLM):
find out more about aviary at
http://github.com/ray-project/aviary
Has no dependencies, since it connects to backend
directly.
To get a list of the models supported on an
aviary, follow the instructions on the web site to
install the aviary CLI and then use:
`aviary models`
You must at least specify the environment
variable or parameter AVIARY_URL.
You may optionally specify the environment variable
or parameter AVIARY_TOKEN.
AVIARY_URL and AVIARY_TOKEN environement variables must be set.
Example:
.. code-block:: python
from langchain.llms import Aviary
light = Aviary(aviary_url='AVIARY_URL',
model='amazon/LightGPT')
result = light.predict('How do you make fried rice?')
os.environ["AVIARY_URL"] = "<URL>"
os.environ["AVIARY_TOKEN"] = "<TOKEN>"
light = Aviary(model='amazon/LightGPT')
output = light('How do you make fried rice?')
"""
model: str
aviary_url: str
aviary_token: str = Field("", exclude=True)
model: str = "amazon/LightGPT"
aviary_url: Optional[str] = None
aviary_token: Optional[str] = None
# If True the prompt template for the model will be ignored.
use_prompt_format: bool = True
# API version to use for Aviary
version: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
@ -56,49 +117,35 @@ class Aviary(LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
if not aviary_url.endswith("/"):
aviary_url += "/"
values["aviary_url"] = aviary_url
aviary_token = get_from_dict_or_env(
values, "aviary_token", "AVIARY_TOKEN", default=""
)
values["aviary_token"] = aviary_token
aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
# Set env viarables for aviary sdk
os.environ["AVIARY_URL"] = aviary_url
os.environ["AVIARY_TOKEN"] = aviary_token
aviary_endpoint = aviary_url + "models"
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
try:
response = requests.get(aviary_endpoint, headers=headers)
result = response.json()
# Confirm model is available
if values["model"] not in result:
raise ValueError(
f"{aviary_url} does not support model {values['model']}."
)
aviary_models = get_models()
except requests.exceptions.RequestException as e:
raise ValueError(e)
model = values.get("model")
if model and model not in aviary_models:
raise ValueError(f"{aviary_url} does not support model {values['model']}.")
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model,
"aviary_url": self.aviary_url,
"aviary_token": self.aviary_token,
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "aviary"
@property
def headers(self) -> Dict[str, str]:
if self.aviary_token:
return {"Authorization": f"Bearer {self.aviary_token}"}
else:
return {}
return f"aviary-{self.model.replace('/', '-')}"
def _call(
self,
@ -119,19 +166,18 @@ class Aviary(LLM):
response = aviary("Tell me a joke.")
"""
url = self.aviary_url + "query/" + self.model.replace("/", "--")
response = requests.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=TIMEOUT,
kwargs = {"use_prompt_format": self.use_prompt_format}
if self.version:
kwargs["version"] = self.version
output = get_completions(
model=self.model,
prompt=prompt,
**kwargs,
)
try:
text = response.json()[self.model]["generated_text"]
except requests.JSONDecodeError as e:
raise ValueError(
f"Error decoding JSON from {url}. Text response: {response.text}",
) from e
text = cast(str, output["generated_text"])
if stop:
text = enforce_stop_tokens(text, stop)
return text

View File

@ -5,6 +5,7 @@ from langchain.llms.aviary import Aviary
def test_aviary_call() -> None:
"""Test valid call to Anyscale."""
llm = Aviary(model="test/model")
llm = Aviary()
output = llm("Say bar:")
print(f"llm answer:\n{output}")
assert isinstance(output, str)