"""Wrapper around AI21 APIs."""
import os
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM


class AI21PenaltyData(BaseModel):
    """Parameters for AI21 penalty customization.

    Field names are camelCase to match the AI21 /complete API schema
    (they are serialized verbatim via ``.dict()``).
    """

    scale: int = 0
    applyToWhitespaces: bool = True
    applyToPunctuations: bool = True
    applyToNumbers: bool = True
    applyToStopwords: bool = True
    applyToEmojis: bool = True


class AI21(BaseModel, LLM):
    """Wrapper around AI21 large language models.

    To use, you should have the environment variable ``AI21_API_KEY``
    set with your API key, or pass it as the ``ai21_api_key`` named
    parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain import AI21
            ai21 = AI21(model="j1-jumbo")
    """

    model: str = "j1-jumbo"
    """Model name to use."""

    temperature: float = 0.7
    """What sampling temperature to use."""

    maxTokens: int = 256
    """The maximum number of tokens to generate in the completion."""

    minTokens: int = 0
    """The minimum number of tokens to generate in the completion."""

    topP: float = 1.0
    """Total probability mass of tokens to consider at each step."""

    presencePenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens."""

    countPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to count."""

    frequencyPenalty: AI21PenaltyData = AI21PenaltyData()
    """Penalizes repeated tokens according to frequency."""

    numResults: int = 1
    """How many completions to generate for each prompt."""

    logitBias: Optional[Dict[str, float]] = None
    """Adjust the probability of specific tokens being generated."""

    ai21_api_key: Optional[str] = None
    """API key; falls back to the ``AI21_API_KEY`` environment variable."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that an API key was passed or exists in the environment.

        The environment is consulted here, at validation time, rather than
        in the field default, so that ``AI21_API_KEY`` set after this module
        is imported is still picked up.
        """
        ai21_api_key = values.get("ai21_api_key") or os.environ.get("AI21_API_KEY")
        if not ai21_api_key:
            raise ValueError(
                "Did not find AI21 API key, please add an environment variable"
                " `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
                " as a named parameter."
            )
        values["ai21_api_key"] = ai21_api_key
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling AI21 API."""
        return {
            "temperature": self.temperature,
            "maxTokens": self.maxTokens,
            "minTokens": self.minTokens,
            "topP": self.topP,
            "presencePenalty": self.presencePenalty.dict(),
            "countPenalty": self.countPenalty.dict(),
            "frequencyPenalty": self.frequencyPenalty.dict(),
            "numResults": self.numResults,
            "logitBias": self.logitBias,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to AI21's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the API responds with a non-200 status code.

        Example:
            .. code-block:: python

                response = ai21("Tell me a joke.")
        """
        if stop is None:
            stop = []
        # NOTE(review): no timeout is set on this request, so a hung server
        # blocks indefinitely — consider adding one; left unchanged here to
        # avoid cutting off legitimately long generations.
        response = requests.post(
            url=f"https://api.ai21.com/studio/v1/{self.model}/complete",
            headers={"Authorization": f"Bearer {self.ai21_api_key}"},
            json={
                "prompt": prompt,
                "stopSequences": stop,
                **self._default_params,
            },
        )
        if response.status_code != 200:
            # The error body is usually JSON, but guard against non-JSON
            # responses (e.g. an HTML error page from a proxy) so we raise
            # the informative ValueError instead of a JSONDecodeError.
            try:
                optional_detail = response.json().get("error")
            except ValueError:
                optional_detail = response.text
            raise ValueError(
                f"AI21 /complete call failed with status code "
                f"{response.status_code}. Details: {optional_detail}"
            )
        response_json = response.json()
        return response_json["completions"][0]["data"]["text"]


# --- tests/integration_tests/llms/test_ai21.py (separate file in the patch) ---
def test_ai21_call() -> None:
    """Integration test: a valid call to AI21 returns a string."""
    llm = AI21(maxTokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)