diff --git a/langchain/llms/aviary.py b/langchain/llms/aviary.py
index bd5a3ebd55f..8408406d305 100644
--- a/langchain/llms/aviary.py
+++ b/langchain/llms/aviary.py
@@ -1,8 +1,10 @@
 """Wrapper around Aviary"""
-from typing import Any, Dict, List, Mapping, Optional
+import dataclasses
+import os
+from typing import Any, Dict, List, Mapping, Optional, Union, cast
 
 import requests
-from pydantic import Extra, Field, root_validator
+from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -12,6 +14,68 @@ from langchain.utils import get_from_dict_or_env
 TIMEOUT = 60
 
 
+@dataclasses.dataclass
+class AviaryBackend:
+    backend_url: str
+    bearer: str
+
+    def __post_init__(self) -> None:
+        self.header = {"Authorization": self.bearer}
+
+    @classmethod
+    def from_env(cls) -> "AviaryBackend":
+        aviary_url = os.getenv("AVIARY_URL")
+        assert aviary_url, "AVIARY_URL must be set"
+
+        aviary_token = os.getenv("AVIARY_TOKEN", "")
+
+        bearer = f"Bearer {aviary_token}" if aviary_token else ""
+        aviary_url += "/" if not aviary_url.endswith("/") else ""
+
+        return cls(aviary_url, bearer)
+
+
+def get_models() -> List[str]:
+    """List available models"""
+    backend = AviaryBackend.from_env()
+    request_url = backend.backend_url + "-/routes"
+    response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
+    try:
+        result = response.json()
+    except requests.JSONDecodeError as e:
+        raise RuntimeError(
+            f"Error decoding JSON from {request_url}. Text response: {response.text}"
+        ) from e
+    result = sorted(
+        [k.lstrip("/").replace("--", "/") for k in result.keys() if "--" in k]
+    )
+    return result
+
+
+def get_completions(
+    model: str,
+    prompt: str,
+    use_prompt_format: bool = True,
+    version: str = "",
+) -> Dict[str, Union[str, float, int]]:
+    """Get completions from Aviary models."""
+
+    backend = AviaryBackend.from_env()
+    url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
+    response = requests.post(
+        url,
+        headers=backend.header,
+        json={"prompt": prompt, "use_prompt_format": use_prompt_format},
+        timeout=TIMEOUT,
+    )
+    try:
+        return response.json()
+    except requests.JSONDecodeError as e:
+        raise RuntimeError(
+            f"Error decoding JSON from {url}. Text response: {response.text}"
+        ) from e
+
+
 class Aviary(LLM):
     """Allow you to use an Aviary.
 
@@ -19,33 +83,30 @@ class Aviary(LLM):
    find out more about aviary at
    http://github.com/ray-project/aviary
 
-    Has no dependencies, since it connects to backend
-    directly.
-
    To get a list of the models supported on an
    aviary, follow the instructions on the web site to
    install the aviary CLI and then use:
    `aviary models`
 
-    You must at least specify the environment
-    variable or parameter AVIARY_URL.
-
-    You may optionally specify the environment variable
-    or parameter AVIARY_TOKEN.
+    AVIARY_URL and AVIARY_TOKEN environment variables must be set.
 
    Example:
        .. code-block:: python
 
            from langchain.llms import Aviary
-            light = Aviary(aviary_url='AVIARY_URL',
-                model='amazon/LightGPT')
-
-            result = light.predict('How do you make fried rice?')
+            os.environ["AVIARY_URL"] = ""
+            os.environ["AVIARY_TOKEN"] = ""
+
+            light = Aviary(model='amazon/LightGPT')
+
+            output = light('How do you make fried rice?')
    """
 
-    model: str
-    aviary_url: str
-    aviary_token: str = Field("", exclude=True)
+    model: str = "amazon/LightGPT"
+    aviary_url: Optional[str] = None
+    aviary_token: Optional[str] = None
+    # If True, the prompt template for the model will be ignored.
+    use_prompt_format: bool = True
+    # API version to use for Aviary
+    version: Optional[str] = None
 
    class Config:
        """Configuration for this pydantic object."""
@@ -56,49 +117,35 @@ class Aviary(LLM):
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
-        if not aviary_url.endswith("/"):
-            aviary_url += "/"
-        values["aviary_url"] = aviary_url
-        aviary_token = get_from_dict_or_env(
-            values, "aviary_token", "AVIARY_TOKEN", default=""
-        )
-        values["aviary_token"] = aviary_token
+        aviary_token = get_from_dict_or_env(values, "aviary_token", "AVIARY_TOKEN")
+
+        # Set env variables for the Aviary SDK
+        os.environ["AVIARY_URL"] = aviary_url
+        os.environ["AVIARY_TOKEN"] = aviary_token
 
-        aviary_endpoint = aviary_url + "models"
-        headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
        try:
-            response = requests.get(aviary_endpoint, headers=headers)
-            result = response.json()
-            # Confirm model is available
-            if values["model"] not in result:
-                raise ValueError(
-                    f"{aviary_url} does not support model {values['model']}."
-                )
-
+            aviary_models = get_models()
        except requests.exceptions.RequestException as e:
            raise ValueError(e)
 
+        model = values.get("model")
+        if model and model not in aviary_models:
+            raise ValueError(f"{aviary_url} does not support model {values['model']}.")
+
        return values
 
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
+            "model_name": self.model,
            "aviary_url": self.aviary_url,
-            "aviary_token": self.aviary_token,
        }
 
    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
-        return "aviary"
-
-    @property
-    def headers(self) -> Dict[str, str]:
-        if self.aviary_token:
-            return {"Authorization": f"Bearer {self.aviary_token}"}
-        else:
-            return {}
+        return f"aviary-{self.model.replace('/', '-')}"
 
    def _call(
        self,
@@ -119,19 +166,18 @@ class Aviary(LLM):
                response = aviary("Tell me a joke.")
        """
-        url = self.aviary_url + "query/" + self.model.replace("/", "--")
-        response = requests.post(
-            url,
-            headers=self.headers,
-            json={"prompt": prompt},
-            timeout=TIMEOUT,
+        kwargs = {"use_prompt_format": self.use_prompt_format}
+        if self.version:
+            kwargs["version"] = self.version
+
+        output = get_completions(
+            model=self.model,
+            prompt=prompt,
+            **kwargs,
        )
-        try:
-            text = response.json()[self.model]["generated_text"]
-        except requests.JSONDecodeError as e:
-            raise ValueError(
-                f"Error decoding JSON from {url}. Text response: {response.text}",
-            ) from e
+
+        text = cast(str, output["generated_text"])
 
        if stop:
            text = enforce_stop_tokens(text, stop)
+
        return text
diff --git a/tests/integration_tests/llms/test_aviary.py b/tests/integration_tests/llms/test_aviary.py
index d2d67fb3f01..41c95a0c6ec 100644
--- a/tests/integration_tests/llms/test_aviary.py
+++ b/tests/integration_tests/llms/test_aviary.py
@@ -5,6 +5,7 @@ from langchain.llms.aviary import Aviary
 
 def test_aviary_call() -> None:
    """Test valid call to Anyscale."""
-    llm = Aviary(model="test/model")
+    llm = Aviary()
    output = llm("Say bar:")
+    print(f"llm answer:\n{output}")
    assert isinstance(output, str)
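
Reviewer note: a minimal usage sketch of the updated interface, assuming AVIARY_URL and AVIARY_TOKEN point at a running Aviary deployment that serves amazon/LightGPT. The endpoint and token values below are placeholders, not part of this change.

import os

from langchain.llms.aviary import Aviary, get_models

# Placeholder values -- replace with a real Aviary deployment and token.
os.environ["AVIARY_URL"] = "http://localhost:8000"
os.environ["AVIARY_TOKEN"] = "<token-if-required>"

# get_models() now lists deployed models via the backend's -/routes endpoint.
print(get_models())

# The model name is checked against that list in validate_environment().
llm = Aviary(model="amazon/LightGPT", use_prompt_format=True)
print(llm("How do you make fried rice?"))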