ai21[patch]: AI21 Labs bump SDK version (#19114)
Description: Added support for AI21 SDK version 2.1.2.
Twitter handle: https://github.com/AI21Labs
Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in: parent edf9d1c905, commit 21c45475c5
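For context before the diff: the refactor below moves the SDK arguments into a _default_params property plus a _build_params_for_request helper, and serializes Penalty objects with .to_dict() before calling the client. The following is a minimal, self-contained sketch of that parameter-merging idea; StubPenalty and build_params_for_request here are hypothetical stand-ins, not APIs of langchain-ai21 or the ai21 SDK.

# Minimal sketch of the request-building behavior this commit introduces:
# penalties serialized with .to_dict(), stop sequences merged in, and a
# "stop defined twice" guard. StubPenalty is a hypothetical stand-in for
# ai21.models.Penalty; only the .to_dict() contract the diff relies on is kept.
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Mapping, Optional


@dataclass
class StubPenalty:
    scale: float
    apply_to_numbers: bool = False

    def to_dict(self) -> Dict[str, Any]:
        # Mirrors the only part of the real Penalty object the new code uses.
        return asdict(self)


def build_params_for_request(
    defaults: Mapping[str, Any],
    frequency_penalty: Optional[StubPenalty] = None,
    stop: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Merge default params, a serialized penalty, stop sequences and overrides."""
    params: Dict[str, Any] = dict(defaults)
    if frequency_penalty is not None:
        params["frequency_penalty"] = frequency_penalty.to_dict()
    if stop is not None:
        if "stop" in kwargs:
            raise ValueError("stop is defined in both stop and kwargs")
        params["stop_sequences"] = stop
    return {**params, **kwargs}


if __name__ == "__main__":
    defaults = {"model": "j2-ultra", "temperature": 0.5, "top_k_return": 0}
    print(build_params_for_request(defaults, StubPenalty(scale=0.2), stop=["\n"]))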
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 from ai21 import AI21Client
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -12,7 +12,7 @@ class AI21Base(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    client: AI21Client = Field(default=None)
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
     api_key: Optional[SecretStr] = None
     api_host: Optional[str] = None
     timeout_sec: Optional[float] = None
@@ -1,8 +1,8 @@
 import asyncio
 from functools import partial
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, List, Mapping, Optional, Tuple, cast
 
-from ai21.models import ChatMessage, Penalty, RoleType
+from ai21.models import ChatMessage, RoleType
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -109,13 +109,13 @@ class ChatAI21(BaseChatModel, AI21Base):
     top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""
 
-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""
 
-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""
 
-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""
 
@@ -129,6 +129,51 @@ class ChatAI21(BaseChatModel, AI21Base):
         """Return type of chat model."""
         return "chat-ai21"
 
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Mapping[str, Any]:
+        params = {}
+        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            "system": system or "",
+            "messages": ai21_messages,
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -136,24 +181,9 @@ class ChatAI21(BaseChatModel, AI21Base):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
 
-        response = self.client.chat.create(
-            model=self.model,
-            messages=ai21_messages,
-            system=system or "",
-            num_results=self.num_results,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            min_tokens=self.min_tokens,
-            top_p=self.top_p,
-            top_k_return=self.top_k_return,
-            stop_sequences=stop,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            **kwargs,
-        )
+        response = self.client.chat.create(**params)
 
         outputs = response.outputs
         message = AIMessage(content=outputs[0].text)
@@ -3,10 +3,11 @@ from functools import partial
 from typing import (
     Any,
     List,
+    Mapping,
     Optional,
 )
 
-from ai21.models import CompletionsResponse, Penalty
+from ai21.models import CompletionsResponse
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -47,16 +48,16 @@ class AI21LLM(BaseLLM, AI21Base):
     top_p: float = 1
     """A value controlling the diversity of the model's responses."""
 
-    top_k_returns: int = 0
+    top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""
 
-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""
 
-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""
 
-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""
 
@@ -73,6 +74,51 @@ class AI21LLM(BaseLLM, AI21Base):
         """Return type of LLM."""
         return "ai21-llm"
 
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.custom_model is not None:
+            base_params["custom_model"] = self.custom_model
+
+        if self.epoch is not None:
+            base_params["epoch"] = self.epoch
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Mapping[str, Any]:
+        params = {}
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         prompts: List[str],
@@ -83,10 +129,10 @@ class AI21LLM(BaseLLM, AI21Base):
         generations: List[List[Generation]] = []
         token_count = 0
 
+        params = self._build_params_for_request(stop=stop, **kwargs)
+
         for prompt in prompts:
-            response = self._invoke_completion(
-                prompt=prompt, model=self.model, stop_sequences=stop, **kwargs
-            )
+            response = self._invoke_completion(prompt=prompt, **params)
             generation = self._response_to_generation(response)
             generations.append(generation)
             token_count += self.client.count_tokens(prompt)
@@ -109,25 +155,11 @@ class AI21LLM(BaseLLM, AI21Base):
     def _invoke_completion(
         self,
         prompt: str,
-        model: str,
-        stop_sequences: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> CompletionsResponse:
         return self.client.completion.create(
             prompt=prompt,
-            model=model,
-            max_tokens=self.max_tokens,
-            num_results=self.num_results,
-            min_tokens=self.min_tokens,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k_return=self.top_k_returns,
-            custom_model=self.custom_model,
-            stop_sequences=stop_sequences,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            epoch=self.epoch,
             **kwargs,
         )
 
     def _response_to_generation(
libs/partners/ai21/poetry.lock (generated, 14 lines changed)
@@ -1,20 +1,21 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "ai21"
-version = "2.0.5"
+version = "2.1.2"
 description = ""
 optional = false
 python-versions = ">=3.8,<4.0"
 files = [
-    {file = "ai21-2.0.5-py3-none-any.whl", hash = "sha256:b10b66d9a8a3a7a010e3f67d39049913ea38d23034a34d18940a1801412bae10"},
-    {file = "ai21-2.0.5.tar.gz", hash = "sha256:3b0c933b7b7268f9d7615f6cab23febc7d78da887eaef002cdc24cd1e9971851"},
+    {file = "ai21-2.1.2-py3-none-any.whl", hash = "sha256:5ca1b5e1f11dc52fd3e894edb288634572ee6f13d8fe456c66a1825812067548"},
+    {file = "ai21-2.1.2.tar.gz", hash = "sha256:8968a2b4a98fdc5b1bca4a9c856a903fa874d0762a2570741efa071b65a1accd"},
 ]
 
 [package.dependencies]
 ai21-tokenizer = ">=0.3.9,<0.4.0"
 dataclasses-json = ">=0.6.3,<0.7.0"
 requests = ">=2.31.0,<3.0.0"
+typing-extensions = ">=4.9.0,<5.0.0"
 
 [package.extras]
 aws = ["boto3 (>=1.28.82,<2.0.0)"]
@@ -299,7 +300,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.27"
+version = "0.1.30"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -725,7 +726,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1009,4 +1009,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "f8512fa6e745dc32b49132bd7327ec204822c861c113b57db21f1f8cf2fac881"
+content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f"
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ai21"
-version = "0.1.1"
+version = "0.1.2"
 description = "An integration package connecting AI21 and LangChain"
 authors = []
 readme = "README.md"
@@ -8,7 +8,7 @@ readme = "README.md"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.22"
-ai21 = "2.0.5"
+ai21 = "^2.1.2"
 
 [tool.poetry.group.test]
 optional = true
@@ -1,3 +1,4 @@
+import os
 from contextlib import contextmanager
 from typing import Generator
 from unittest.mock import Mock
@@ -31,11 +32,30 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
     "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
     "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
     "count_penalty": Penalty(
-        scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
     ),
 }
 
 
+BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
+    "num_results": 3,
+    "max_tokens": 20,
+    "min_tokens": 10,
+    "temperature": 0.5,
+    "top_p": 0.5,
+    "top_k_return": 0,
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
+    "count_penalty": Penalty(
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
+    ).to_dict(),
+}
+
+
 @pytest.fixture
 def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
@@ -86,10 +106,12 @@ def temporarily_unset_api_key() -> Generator:
     """
     api_key = AI21EnvConfig.api_key
     AI21EnvConfig.api_key = None
+    os.environ.pop("AI21_API_KEY", None)
     yield
 
     if api_key is not None:
         AI21EnvConfig.api_key = api_key
+        os.environ["AI21_API_KEY"] = api_key
 
 
 @pytest.fixture
@@ -22,6 +22,7 @@ from langchain_ai21.chat_models import (
 )
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -46,7 +47,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     min_tokens = 20
     temperature = 0.1
     top_p = 0.1
-    top_k_returns = 0
+    top_k_return = 0
     frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
     presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
     count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
@@ -59,7 +60,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
         min_tokens=min_tokens,
         temperature=temperature,
         top_p=top_p,
-        top_k_returns=top_k_returns,
+        top_k_return=top_k_return,
         frequency_penalty=frequency_penalty,
         presence_penalty=presence_penalty,
         count_penalty=count_penalty,
@@ -70,7 +71,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     assert llm.min_tokens == min_tokens
     assert llm.temperature == temperature
    assert llm.top_p == top_p
-    assert llm.top_k_return == top_k_returns
+    assert llm.top_k_return == top_k_return
     assert llm.frequency_penalty == frequency_penalty
     assert llm.presence_penalty == presence_penalty
     assert count_penalty == count_penalty
@@ -180,14 +181,14 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
         client=mock_client_with_chat,
         **BASIC_EXAMPLE_LLM_PARAMETERS,
     )
-    llm.invoke(input=chat_input, config=dict(tags=["foo"]))
+    llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
 
     mock_client_with_chat.chat.create.assert_called_once_with(
         model="j2-ultra",
         messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
         system="",
-        stop_sequences=None,
-        **BASIC_EXAMPLE_LLM_PARAMETERS,
+        stop_sequences=["\n"],
+        **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     )
 
 
@@ -223,8 +224,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                     ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
                 ],
                 system="",
-                stop_sequences=None,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
             call(
                 model="j2-ultra",
@@ -232,8 +232,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                     ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
                 ],
                 system="system message",
-                stop_sequences=None,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
         ]
     )
@@ -10,6 +10,7 @@ from ai21.models import (
 from langchain_ai21 import AI21LLM
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -42,7 +43,7 @@ def test_initialization__when_custom_parameters_to_init() -> None:
         min_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k_returns=0,
+        top_k_return=0,
         stop_sequences=["\n"],
         frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
         presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
@@ -93,7 +94,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
                 custom_model=custom_model,
                 stop_sequences=stop,
                 epoch=epoch,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
             call(
                 prompt=prompt1,
@@ -101,7 +102,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
                 custom_model=custom_model,
                 stop_sequences=stop,
                 epoch=epoch,
-                **BASIC_EXAMPLE_LLM_PARAMETERS,
+                **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
             ),
         ]
    )
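The updated unit tests above imply the following end-to-end usage. This is a hedged sketch, assuming langchain-ai21 0.1.2 (ai21 SDK >= 2.1.2) is installed and a valid AI21_API_KEY is set in the environment; the parameter values simply mirror BASIC_EXAMPLE_LLM_PARAMETERS from the test fixtures.

# Hedged usage sketch based on the unit-test expectations above; assumes
# langchain-ai21 0.1.2 with ai21 >= 2.1.2 installed and AI21_API_KEY set.
from ai21.models import Penalty
from langchain_ai21.chat_models import ChatAI21

chat = ChatAI21(
    model="j2-ultra",
    num_results=3,
    max_tokens=20,
    min_tokens=10,
    temperature=0.5,
    top_p=0.5,
    top_k_return=0,
    frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
    presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
)

# Penalty objects are serialized via .to_dict() and stop=["\n"] is forwarded
# to the SDK as stop_sequences, matching test_invoke above.
print(chat.invoke("Tell me a joke", stop=["\n"]).content)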