ai21[patch]: AI21 Labs bump SDK version (#19114)
Description: Added support for AI21 SDK version 2.1.2.
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
commit 21c45475c5
parent edf9d1c905
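Apart from renaming the AI21LLM field top_k_returns to top_k_return, the public surface of ChatAI21 and AI21LLM is unchanged by this bump. A minimal, illustrative usage sketch (not part of this commit; it assumes AI21_API_KEY is set in the environment and uses the "j2-ultra" model name that appears in the unit tests below):

# Illustrative only (not part of this commit). Assumes AI21_API_KEY is set in the
# environment and uses the "j2-ultra" model name from the unit tests below.
from langchain_ai21 import AI21LLM
from langchain_ai21.chat_models import ChatAI21
from langchain_core.messages import HumanMessage

chat = ChatAI21(model="j2-ultra", temperature=0.5, top_k_return=0)
print(chat.invoke([HumanMessage(content="Hello, AI21!")]).content)

llm = AI21LLM(model="j2-ultra", max_tokens=20)
print(llm.invoke("Write a one-line greeting."))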
libs/partners/ai21/langchain_ai21/ai21_base.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 from ai21 import AI21Client
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -12,7 +12,7 @@ class AI21Base(BaseModel):
     class Config:
         arbitrary_types_allowed = True

-    client: AI21Client = Field(default=None)
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
     api_key: Optional[SecretStr] = None
     api_host: Optional[str] = None
     timeout_sec: Optional[float] = None
libs/partners/ai21/langchain_ai21/chat_models.py
@@ -1,8 +1,8 @@
 import asyncio
 from functools import partial
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, List, Mapping, Optional, Tuple, cast

-from ai21.models import ChatMessage, Penalty, RoleType
+from ai21.models import ChatMessage, RoleType
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -109,13 +109,13 @@ class ChatAI21(BaseChatModel, AI21Base):
     top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""

-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""

-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""

-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""

@@ -129,6 +129,51 @@ class ChatAI21(BaseChatModel, AI21Base):
         """Return type of chat model."""
         return "chat-ai21"

+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Mapping[str, Any]:
+        params = {}
+        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            "system": system or "",
+            "messages": ai21_messages,
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -136,24 +181,9 @@ class ChatAI21(BaseChatModel, AI21Base):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)

-        response = self.client.chat.create(
-            model=self.model,
-            messages=ai21_messages,
-            system=system or "",
-            num_results=self.num_results,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            min_tokens=self.min_tokens,
-            top_p=self.top_p,
-            top_k_return=self.top_k_return,
-            stop_sequences=stop,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            **kwargs,
-        )
+        response = self.client.chat.create(**params)

         outputs = response.outputs
         message = AIMessage(content=outputs[0].text)
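Both chat and completion calls are now assembled by _build_params_for_request, which layers _default_params, the optional stop handling, and caller kwargs into one mapping, with later entries winning. A standalone sketch of that merge order (illustrative, not the library code; build_params and defaults are names made up for the example):

from typing import Any, Dict, List, Optional

def build_params(
    defaults: Dict[str, Any],
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # Mirror of the merge order used above: defaults first, then stop handling,
    # then caller kwargs, so the caller always has the last word.
    params: Dict[str, Any] = {}
    if stop is not None:
        if "stop" in kwargs:
            raise ValueError("stop is defined in both stop and kwargs")
        params["stop_sequences"] = stop
    return {**defaults, **params, **kwargs}

merged = build_params({"temperature": 0.7, "top_p": 1}, stop=["\n"], temperature=0.2)
print(merged)  # {'temperature': 0.2, 'top_p': 1, 'stop_sequences': ['\n']}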
libs/partners/ai21/langchain_ai21/llms.py
@@ -3,10 +3,11 @@ from functools import partial
 from typing import (
     Any,
     List,
+    Mapping,
     Optional,
 )

-from ai21.models import CompletionsResponse, Penalty
+from ai21.models import CompletionsResponse
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -47,16 +48,16 @@ class AI21LLM(BaseLLM, AI21Base):
     top_p: float = 1
     """A value controlling the diversity of the model's responses."""

-    top_k_returns: int = 0
+    top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""

-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""

-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""

-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""

@@ -73,6 +74,51 @@ class AI21LLM(BaseLLM, AI21Base):
         """Return type of LLM."""
         return "ai21-llm"

+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.custom_model is not None:
+            base_params["custom_model"] = self.custom_model
+
+        if self.epoch is not None:
+            base_params["epoch"] = self.epoch
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Mapping[str, Any]:
+        params = {}
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         prompts: List[str],
@@ -83,10 +129,10 @@ class AI21LLM(BaseLLM, AI21Base):
         generations: List[List[Generation]] = []
         token_count = 0

+        params = self._build_params_for_request(stop=stop, **kwargs)
+
         for prompt in prompts:
-            response = self._invoke_completion(
-                prompt=prompt, model=self.model, stop_sequences=stop, **kwargs
-            )
+            response = self._invoke_completion(prompt=prompt, **params)
             generation = self._response_to_generation(response)
             generations.append(generation)
             token_count += self.client.count_tokens(prompt)
@@ -109,25 +155,11 @@ class AI21LLM(BaseLLM, AI21Base):
     def _invoke_completion(
         self,
         prompt: str,
-        model: str,
-        stop_sequences: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> CompletionsResponse:
         return self.client.completion.create(
             prompt=prompt,
-            model=model,
-            max_tokens=self.max_tokens,
-            num_results=self.num_results,
-            min_tokens=self.min_tokens,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k_return=self.top_k_returns,
-            custom_model=self.custom_model,
-            stop_sequences=stop_sequences,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            epoch=self.epoch,
+            **kwargs,
         )

     def _response_to_generation(
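Because the penalty fields are now typed Optional[Any], both models serialize them with Penalty.to_dict() inside _default_params, so client.chat.create and client.completion.create receive plain dictionaries rather than Penalty instances. A small illustrative check (assumes the ai21 package from the lockfile below is installed):

from ai21.models import Penalty

penalty = Penalty(scale=0.2, apply_to_numbers=True)
payload = penalty.to_dict()  # plain dict, matching what _default_params forwards
print(type(payload).__name__, payload)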
libs/partners/ai21/poetry.lock (generated)
@@ -1,20 +1,21 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.

 [[package]]
 name = "ai21"
-version = "2.0.5"
+version = "2.1.2"
 description = ""
 optional = false
 python-versions = ">=3.8,<4.0"
 files = [
-    {file = "ai21-2.0.5-py3-none-any.whl", hash = "sha256:b10b66d9a8a3a7a010e3f67d39049913ea38d23034a34d18940a1801412bae10"},
-    {file = "ai21-2.0.5.tar.gz", hash = "sha256:3b0c933b7b7268f9d7615f6cab23febc7d78da887eaef002cdc24cd1e9971851"},
+    {file = "ai21-2.1.2-py3-none-any.whl", hash = "sha256:5ca1b5e1f11dc52fd3e894edb288634572ee6f13d8fe456c66a1825812067548"},
+    {file = "ai21-2.1.2.tar.gz", hash = "sha256:8968a2b4a98fdc5b1bca4a9c856a903fa874d0762a2570741efa071b65a1accd"},
 ]

 [package.dependencies]
 ai21-tokenizer = ">=0.3.9,<0.4.0"
 dataclasses-json = ">=0.6.3,<0.7.0"
 requests = ">=2.31.0,<3.0.0"
+typing-extensions = ">=4.9.0,<5.0.0"

 [package.extras]
 aws = ["boto3 (>=1.28.82,<2.0.0)"]
@@ -299,7 +300,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.1.27"
+version = "0.1.30"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -725,7 +726,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1009,4 +1009,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "f8512fa6e745dc32b49132bd7327ec204822c861c113b57db21f1f8cf2fac881"
+content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f"
libs/partners/ai21/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ai21"
-version = "0.1.1"
+version = "0.1.2"
 description = "An integration package connecting AI21 and LangChain"
 authors = []
 readme = "README.md"
@@ -8,7 +8,7 @@ readme = "README.md"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.22"
-ai21 = "2.0.5"
+ai21 = "^2.1.2"

 [tool.poetry.group.test]
 optional = true
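The dependency also moves from the exact pin ai21 = "2.0.5" to the caret constraint ^2.1.2, which Poetry resolves as >=2.1.2,<3.0.0, so later 2.x releases are picked up automatically. A small illustrative check of that range (uses the third-party packaging library, which is an assumption here and not part of this commit):

from packaging.specifiers import SpecifierSet

# "^2.1.2" in Poetry is shorthand for ">=2.1.2,<3.0.0".
spec = SpecifierSet(">=2.1.2,<3.0.0")
print("2.0.5" in spec, "2.1.2" in spec, "2.9.0" in spec, "3.0.0" in spec)
# False True True False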
libs/partners/ai21/tests/unit_tests/conftest.py
@@ -1,3 +1,4 @@
+import os
 from contextlib import contextmanager
 from typing import Generator
 from unittest.mock import Mock
@@ -31,11 +32,30 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
     "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
     "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
     "count_penalty": Penalty(
-        scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
     ),
 }


+BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
+    "num_results": 3,
+    "max_tokens": 20,
+    "min_tokens": 10,
+    "temperature": 0.5,
+    "top_p": 0.5,
+    "top_k_return": 0,
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
+    "count_penalty": Penalty(
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
+    ).to_dict(),
+}
+
+
 @pytest.fixture
 def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
@@ -86,10 +106,12 @@ def temporarily_unset_api_key() -> Generator:
     """
     api_key = AI21EnvConfig.api_key
     AI21EnvConfig.api_key = None
+    os.environ.pop("AI21_API_KEY", None)
     yield

     if api_key is not None:
         AI21EnvConfig.api_key = api_key
+        os.environ["AI21_API_KEY"] = api_key


 @pytest.fixture
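The fixture now also clears AI21_API_KEY from os.environ and restores it on exit, in addition to resetting AI21EnvConfig. A hypothetical sketch of how such a helper is typically consumed (the test name and exception type are assumptions, not taken from this diff):

import pytest

from langchain_ai21 import AI21LLM
from tests.unit_tests.conftest import temporarily_unset_api_key


def test_initialization__when_no_api_key() -> None:
    # With the key removed from both AI21EnvConfig and os.environ, constructing
    # the model should fail; the concrete exception type is an assumption here.
    with temporarily_unset_api_key():
        with pytest.raises(Exception):
            AI21LLM(model="j2-ultra")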
libs/partners/ai21/tests/unit_tests/test_chat_models.py
@@ -22,6 +22,7 @@ from langchain_ai21.chat_models import (
 )
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -46,7 +47,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     min_tokens = 20
     temperature = 0.1
     top_p = 0.1
-    top_k_returns = 0
+    top_k_return = 0
     frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
     presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
     count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
|
|||||||
min_tokens=min_tokens,
|
min_tokens=min_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k_returns=top_k_returns,
|
top_k_return=top_k_return,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
count_penalty=count_penalty,
|
count_penalty=count_penalty,
|
||||||
@ -70,7 +71,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
|
|||||||
assert llm.min_tokens == min_tokens
|
assert llm.min_tokens == min_tokens
|
||||||
assert llm.temperature == temperature
|
assert llm.temperature == temperature
|
||||||
assert llm.top_p == top_p
|
assert llm.top_p == top_p
|
||||||
assert llm.top_k_return == top_k_returns
|
assert llm.top_k_return == top_k_return
|
||||||
assert llm.frequency_penalty == frequency_penalty
|
assert llm.frequency_penalty == frequency_penalty
|
||||||
assert llm.presence_penalty == presence_penalty
|
assert llm.presence_penalty == presence_penalty
|
||||||
assert count_penalty == count_penalty
|
assert count_penalty == count_penalty
|
||||||
@ -180,14 +181,14 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
|
|||||||
client=mock_client_with_chat,
|
client=mock_client_with_chat,
|
||||||
**BASIC_EXAMPLE_LLM_PARAMETERS,
|
**BASIC_EXAMPLE_LLM_PARAMETERS,
|
||||||
)
|
)
|
||||||
llm.invoke(input=chat_input, config=dict(tags=["foo"]))
|
llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
|
||||||
|
|
||||||
mock_client_with_chat.chat.create.assert_called_once_with(
|
mock_client_with_chat.chat.create.assert_called_once_with(
|
||||||
model="j2-ultra",
|
model="j2-ultra",
|
||||||
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
|
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
|
||||||
system="",
|
system="",
|
||||||
stop_sequences=None,
|
stop_sequences=["\n"],
|
||||||
**BASIC_EXAMPLE_LLM_PARAMETERS,
|
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -223,8 +224,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                     ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
                 ],
                 system="",
-                stop_sequences=None,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
             call(
                 model="j2-ultra",
@@ -232,8 +232,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                     ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
                 ],
                 system="system message",
-                stop_sequences=None,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
         ]
     )
libs/partners/ai21/tests/unit_tests/test_llms.py
@@ -10,6 +10,7 @@ from ai21.models import (
 from langchain_ai21 import AI21LLM
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -42,7 +43,7 @@ def test_initialization__when_custom_parameters_to_init() -> None:
         min_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k_returns=0,
+        top_k_return=0,
         stop_sequences=["\n"],
         frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
         presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
@@ -93,7 +94,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
                 custom_model=custom_model,
                 stop_sequences=stop,
                 epoch=epoch,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
             call(
                 prompt=prompt1,
@@ -101,7 +102,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
                 custom_model=custom_model,
                 stop_sequences=stop,
                 epoch=epoch,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
         ]
     )