mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
Harrison/cohere experimental (#638)
Co-authored-by: inyourhead <44607279+xettrisomeman@users.noreply.github.com>
This commit is contained in:
parent
5c97f70bf1
commit
4d4cff0530
@ -64,6 +64,9 @@ class AI21(LLM, BaseModel):
|
|||||||
|
|
||||||
ai21_api_key: Optional[str] = None
|
ai21_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
"""Base url to use, if None decides based on model name."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -118,8 +121,15 @@ class AI21(LLM, BaseModel):
|
|||||||
"""
|
"""
|
||||||
if stop is None:
|
if stop is None:
|
||||||
stop = []
|
stop = []
|
||||||
|
if self.base_url is not None:
|
||||||
|
base_url = self.base_url
|
||||||
|
else:
|
||||||
|
if self.model in ("j1-grande-instruct",):
|
||||||
|
base_url = "https://api.ai21.com/studio/v1/experimental"
|
||||||
|
else:
|
||||||
|
base_url = "https://api.ai21.com/studio/v1"
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url=f"https://api.ai21.com/studio/v1/{self.model}/complete",
|
url=f"{base_url}/{self.model}/complete",
|
||||||
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
|
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
|
||||||
json={"prompt": prompt, "stopSequences": stop, **self._default_params},
|
json={"prompt": prompt, "stopSequences": stop, **self._default_params},
|
||||||
)
|
)
|
||||||
|
@ -13,6 +13,13 @@ def test_ai21_call() -> None:
|
|||||||
assert isinstance(output, str)
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ai21_call_experimental() -> None:
|
||||||
|
"""Test valid call to ai21 with an experimental model."""
|
||||||
|
llm = AI21(maxTokens=10, model="j1-grande-instruct")
|
||||||
|
output = llm("Say foo:")
|
||||||
|
assert isinstance(output, str)
|
||||||
|
|
||||||
|
|
||||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||||
"""Test saving/loading an AI21 LLM."""
|
"""Test saving/loading an AI21 LLM."""
|
||||||
llm = AI21(maxTokens=10)
|
llm = AI21(maxTokens=10)
|
||||||
|
Loading…
Reference in New Issue
Block a user