ai21[patch]: AI21 Labs bump SDK version (#19114)

Description: Added support for AI21 SDK version 2.1.2
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Authored by Asaf Joseph Gardin on 2024-03-19 04:47:08 +02:00; committed by GitHub
parent edf9d1c905
commit 21c45475c5
8 changed files with 154 additions and 70 deletions
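
The package's public surface is unchanged by the bump. A minimal usage sketch, assuming langchain-ai21 0.1.2 with ai21 >= 2.1.2 installed and AI21_API_KEY set in the environment (model and prompt are illustrative):

from langchain_ai21 import ChatAI21

# "j2-ultra" matches the model id used in the unit tests below; any supported AI21 model id works.
chat = ChatAI21(model="j2-ultra", temperature=0.5)
print(chat.invoke("Say hello in two words.").content)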

libs/partners/ai21/langchain_ai21/ai21_base.py

@@ -1,5 +1,5 @@
import os
from typing import Dict, Optional
from typing import Any, Dict, Optional
from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -12,7 +12,7 @@ class AI21Base(BaseModel):
class Config:
arbitrary_types_allowed = True
client: AI21Client = Field(default=None)
client: Any = Field(default=None, exclude=True) #: :meta private:
api_key: Optional[SecretStr] = None
api_host: Optional[str] = None
timeout_sec: Optional[float] = None
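
Typing the cached client as Any and marking it exclude=True keeps the raw AI21Client out of pydantic validation and serialization. A rough illustration of the intent, assuming the client is still populated by the module's existing root validator and that the key is only checked when a request is made (the key below is a placeholder):

from langchain_ai21.ai21_base import AI21Base

base = AI21Base(api_key="placeholder-key")
assert "client" not in base.dict()  # exclude=True drops the SDK client from serialized dumps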

libs/partners/ai21/langchain_ai21/chat_models.py

@@ -1,8 +1,8 @@
import asyncio
from functools import partial
from typing import Any, List, Optional, Tuple, cast
from typing import Any, List, Mapping, Optional, Tuple, cast
from ai21.models import ChatMessage, Penalty, RoleType
from ai21.models import ChatMessage, RoleType
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -109,13 +109,13 @@ class ChatAI21(BaseChatModel, AI21Base):
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step."""
frequency_penalty: Optional[Penalty] = None
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated."""
presence_penalty: Optional[Penalty] = None
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt."""
count_penalty: Optional[Penalty] = None
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses."""
@@ -129,6 +129,51 @@ class ChatAI21(BaseChatModel, AI21Base):
"""Return type of chat model."""
return "chat-ai21"
@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
}
if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()
if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()
return base_params
def _build_params_for_request(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Mapping[str, Any]:
params = {}
system, ai21_messages = _convert_messages_to_ai21_messages(messages)
if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop
return {
"system": system or "",
"messages": ai21_messages,
**self._default_params,
**params,
**kwargs,
}
def _generate(
self,
messages: List[BaseMessage],
@@ -136,24 +181,9 @@ class ChatAI21(BaseChatModel, AI21Base):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system, ai21_messages = _convert_messages_to_ai21_messages(messages)
params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
response = self.client.chat.create(
model=self.model,
messages=ai21_messages,
system=system or "",
num_results=self.num_results,
temperature=self.temperature,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
top_p=self.top_p,
top_k_return=self.top_k_return,
stop_sequences=stop,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
count_penalty=self.count_penalty,
**kwargs,
)
response = self.client.chat.create(**params)
outputs = response.outputs
message = AIMessage(content=outputs[0].text)
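
With this refactor every chat request is assembled in one place: _default_params supplies the static keyword arguments, Penalty objects are serialized with .to_dict(), and _build_params_for_request merges in the converted messages, stop sequences, and call-time kwargs before a single client.chat.create(**params) call. A caller-side sketch, assuming an API key is configured (all values are illustrative):

from ai21.models import Penalty
from langchain_ai21 import ChatAI21

chat = ChatAI21(
    model="j2-ultra",
    temperature=0.5,
    frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),  # kept as a Penalty on the model
)
# stop=["\n"] is forwarded to the SDK as stop_sequences, and the Penalty arrives as a plain dict
# (see the updated test_invoke expectations further down).
chat.invoke("Tell me a joke", stop=["\n"])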

libs/partners/ai21/langchain_ai21/llms.py

@@ -3,10 +3,11 @@ from functools import partial
from typing import (
Any,
List,
Mapping,
Optional,
)
from ai21.models import CompletionsResponse, Penalty
from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -47,16 +48,16 @@ class AI21LLM(BaseLLM, AI21Base):
top_p: float = 1
"""A value controlling the diversity of the model's responses."""
top_k_returns: int = 0
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step."""
frequency_penalty: Optional[Penalty] = None
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated."""
presence_penalty: Optional[Penalty] = None
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt."""
count_penalty: Optional[Penalty] = None
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses."""
@@ -73,6 +74,51 @@ class AI21LLM(BaseLLM, AI21Base):
"""Return type of LLM."""
return "ai21-llm"
@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
}
if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()
if self.custom_model is not None:
base_params["custom_model"] = self.custom_model
if self.epoch is not None:
base_params["epoch"] = self.epoch
if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()
return base_params
def _build_params_for_request(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Mapping[str, Any]:
params = {}
if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop
return {
**self._default_params,
**params,
**kwargs,
}
def _generate(
self,
prompts: List[str],
@@ -83,10 +129,10 @@ class AI21LLM(BaseLLM, AI21Base):
generations: List[List[Generation]] = []
token_count = 0
params = self._build_params_for_request(stop=stop, **kwargs)
for prompt in prompts:
response = self._invoke_completion(
prompt=prompt, model=self.model, stop_sequences=stop, **kwargs
)
response = self._invoke_completion(prompt=prompt, **params)
generation = self._response_to_generation(response)
generations.append(generation)
token_count += self.client.count_tokens(prompt)
@@ -109,25 +155,11 @@ class AI21LLM(BaseLLM, AI21Base):
def _invoke_completion(
self,
prompt: str,
model: str,
stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> CompletionsResponse:
return self.client.completion.create(
prompt=prompt,
model=model,
max_tokens=self.max_tokens,
num_results=self.num_results,
min_tokens=self.min_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k_return=self.top_k_returns,
custom_model=self.custom_model,
stop_sequences=stop_sequences,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
count_penalty=self.count_penalty,
epoch=self.epoch,
**kwargs,
)
def _response_to_generation(
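
The completions path mirrors the chat path: request keyword arguments come from _default_params (with top_k_returns renamed to top_k_return and Penalty objects converted via .to_dict()), and _invoke_completion now only adds the prompt. A caller-side sketch, assuming an API key is configured (model and prompt are illustrative):

from langchain_ai21 import AI21LLM

llm = AI21LLM(model="j2-ultra", max_tokens=20, top_k_return=0)
print(llm.invoke("Write a haiku about version bumps.", stop=["\n"]))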

libs/partners/ai21/poetry.lock

@@ -1,20 +1,21 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "ai21"
version = "2.0.5"
version = "2.1.2"
description = ""
optional = false
python-versions = ">=3.8,<4.0"
files = [
{file = "ai21-2.0.5-py3-none-any.whl", hash = "sha256:b10b66d9a8a3a7a010e3f67d39049913ea38d23034a34d18940a1801412bae10"},
{file = "ai21-2.0.5.tar.gz", hash = "sha256:3b0c933b7b7268f9d7615f6cab23febc7d78da887eaef002cdc24cd1e9971851"},
{file = "ai21-2.1.2-py3-none-any.whl", hash = "sha256:5ca1b5e1f11dc52fd3e894edb288634572ee6f13d8fe456c66a1825812067548"},
{file = "ai21-2.1.2.tar.gz", hash = "sha256:8968a2b4a98fdc5b1bca4a9c856a903fa874d0762a2570741efa071b65a1accd"},
]
[package.dependencies]
ai21-tokenizer = ">=0.3.9,<0.4.0"
dataclasses-json = ">=0.6.3,<0.7.0"
requests = ">=2.31.0,<3.0.0"
typing-extensions = ">=4.9.0,<5.0.0"
[package.extras]
aws = ["boto3 (>=1.28.82,<2.0.0)"]
@@ -299,7 +300,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.27"
version = "0.1.30"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -725,7 +726,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1009,4 +1009,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "f8512fa6e745dc32b49132bd7327ec204822c861c113b57db21f1f8cf2fac881"
content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f"

libs/partners/ai21/pyproject.toml

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-ai21"
version = "0.1.1"
version = "0.1.2"
description = "An integration package connecting AI21 and LangChain"
authors = []
readme = "README.md"
@@ -8,7 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.22"
ai21 = "2.0.5"
ai21 = "^2.1.2"
[tool.poetry.group.test]
optional = true

libs/partners/ai21/tests/unit_tests/conftest.py

@@ -1,3 +1,4 @@
import os
from contextlib import contextmanager
from typing import Generator
from unittest.mock import Mock
@@ -31,11 +32,30 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
}
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
}
@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
@@ -86,10 +106,12 @@ def temporarily_unset_api_key() -> Generator:
"""
api_key = AI21EnvConfig.api_key
AI21EnvConfig.api_key = None
os.environ.pop("AI21_API_KEY", None)
yield
if api_key is not None:
AI21EnvConfig.api_key = api_key
os.environ["AI21_API_KEY"] = api_key
@pytest.fixture
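
temporarily_unset_api_key() now clears the AI21_API_KEY environment variable in addition to AI21EnvConfig.api_key, presumably because the newer SDK can read the key straight from the environment when a client is constructed. A short sketch of the guarantee the missing-key tests rely on:

import os

from tests.unit_tests.conftest import temporarily_unset_api_key

with temporarily_unset_api_key():
    # Inside the context both AI21EnvConfig.api_key and the raw env var are cleared;
    # on exit they are restored if a key was present before.
    assert os.getenv("AI21_API_KEY") is None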

libs/partners/ai21/tests/unit_tests/test_chat_models.py

@@ -22,6 +22,7 @@ from langchain_ai21.chat_models import (
)
from tests.unit_tests.conftest import (
BASIC_EXAMPLE_LLM_PARAMETERS,
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
DUMMY_API_KEY,
temporarily_unset_api_key,
)
@@ -46,7 +47,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
min_tokens = 20
temperature = 0.1
top_p = 0.1
top_k_returns = 0
top_k_return = 0
frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
@@ -59,7 +60,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
min_tokens=min_tokens,
temperature=temperature,
top_p=top_p,
top_k_returns=top_k_returns,
top_k_return=top_k_return,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
count_penalty=count_penalty,
@@ -70,7 +71,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
assert llm.min_tokens == min_tokens
assert llm.temperature == temperature
assert llm.top_p == top_p
assert llm.top_k_return == top_k_returns
assert llm.top_k_return == top_k_return
assert llm.frequency_penalty == frequency_penalty
assert llm.presence_penalty == presence_penalty
assert count_penalty == count_penalty
@@ -180,14 +181,14 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
client=mock_client_with_chat,
**BASIC_EXAMPLE_LLM_PARAMETERS,
)
llm.invoke(input=chat_input, config=dict(tags=["foo"]))
llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
mock_client_with_chat.chat.create.assert_called_once_with(
model="j2-ultra",
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
system="",
stop_sequences=None,
**BASIC_EXAMPLE_LLM_PARAMETERS,
stop_sequences=["\n"],
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
)
@@ -223,8 +224,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
],
system="",
stop_sequences=None,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
call(
model="j2-ultra",
@@ -232,8 +232,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
],
system="system message",
stop_sequences=None,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
]
)

libs/partners/ai21/tests/unit_tests/test_llms.py

@@ -10,6 +10,7 @@ from ai21.models import (
from langchain_ai21 import AI21LLM
from tests.unit_tests.conftest import (
BASIC_EXAMPLE_LLM_PARAMETERS,
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
DUMMY_API_KEY,
temporarily_unset_api_key,
)
@@ -42,7 +43,7 @@ def test_initialization__when_custom_parameters_to_init() -> None:
min_tokens=10,
temperature=0.5,
top_p=0.5,
top_k_returns=0,
top_k_return=0,
stop_sequences=["\n"],
frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
@@ -93,7 +94,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
custom_model=custom_model,
stop_sequences=stop,
epoch=epoch,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
call(
prompt=prompt1,
@@ -101,7 +102,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
custom_model=custom_model,
stop_sequences=stop,
epoch=epoch,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
),
]
)