From 4d4cff0530a2ae6c7f5eca2cc7362310725dd2f2 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 17 Jan 2023 22:28:55 -0800 Subject: [PATCH] Harrison/cohere experimental (#638) Co-authored-by: inyourhead <44607279+xettrisomeman@users.noreply.github.com> --- langchain/llms/ai21.py | 12 +++++++++++- tests/integration_tests/llms/test_ai21.py | 7 +++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 77a9300d70a..efbb8889139 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -64,6 +64,9 @@ class AI21(LLM, BaseModel): ai21_api_key: Optional[str] = None + base_url: Optional[str] = None + """Base url to use, if None decides based on model name.""" + class Config: """Configuration for this pydantic object.""" @@ -118,8 +121,15 @@ class AI21(LLM, BaseModel): """ if stop is None: stop = [] + if self.base_url is not None: + base_url = self.base_url + else: + if self.model in ("j1-grande-instruct",): + base_url = "https://api.ai21.com/studio/v1/experimental" + else: + base_url = "https://api.ai21.com/studio/v1" response = requests.post( - url=f"https://api.ai21.com/studio/v1/{self.model}/complete", + url=f"{base_url}/{self.model}/complete", headers={"Authorization": f"Bearer {self.ai21_api_key}"}, json={"prompt": prompt, "stopSequences": stop, **self._default_params}, ) diff --git a/tests/integration_tests/llms/test_ai21.py b/tests/integration_tests/llms/test_ai21.py index 16a8716d6e8..6e56e52694f 100644 --- a/tests/integration_tests/llms/test_ai21.py +++ b/tests/integration_tests/llms/test_ai21.py @@ -13,6 +13,13 @@ def test_ai21_call() -> None: assert isinstance(output, str) +def test_ai21_call_experimental() -> None: + """Test valid call to ai21 with an experimental model.""" + llm = AI21(maxTokens=10, model="j1-grande-instruct") + output = llm("Say foo:") + assert isinstance(output, str) + + def test_saving_loading_llm(tmp_path: Path) -> None: """Test saving/loading an AI21 LLM.""" llm = AI21(maxTokens=10)