diff --git a/libs/partners/ai21/langchain_ai21/ai21_base.py b/libs/partners/ai21/langchain_ai21/ai21_base.py
index 39c5ffbf1f0..fa9f30ed803 100644
--- a/libs/partners/ai21/langchain_ai21/ai21_base.py
+++ b/libs/partners/ai21/langchain_ai21/ai21_base.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 
 from ai21 import AI21Client
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -12,7 +12,7 @@ class AI21Base(BaseModel):
     class Config:
         arbitrary_types_allowed = True
 
-    client: AI21Client = Field(default=None)
+    client: Any = Field(default=None, exclude=True)  #: :meta private:
     api_key: Optional[SecretStr] = None
     api_host: Optional[str] = None
     timeout_sec: Optional[float] = None
diff --git a/libs/partners/ai21/langchain_ai21/chat_models.py b/libs/partners/ai21/langchain_ai21/chat_models.py
index 0839a493862..064e7dae93e 100644
--- a/libs/partners/ai21/langchain_ai21/chat_models.py
+++ b/libs/partners/ai21/langchain_ai21/chat_models.py
@@ -1,8 +1,8 @@
 import asyncio
 from functools import partial
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, List, Mapping, Optional, Tuple, cast
 
-from ai21.models import ChatMessage, Penalty, RoleType
+from ai21.models import ChatMessage, RoleType
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -109,13 +109,13 @@ class ChatAI21(BaseChatModel, AI21Base):
     top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""
 
-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""
 
-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""
 
-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""
 
@@ -129,6 +129,51 @@ class ChatAI21(BaseChatModel, AI21Base):
         """Return type of chat model."""
         return "chat-ai21"
 
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Mapping[str, Any]:
+        params = {}
+        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            "system": system or "",
+            "messages": ai21_messages,
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -136,24 +181,9 @@ class ChatAI21(BaseChatModel, AI21Base):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        system, ai21_messages = _convert_messages_to_ai21_messages(messages)
+        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
 
-        response = self.client.chat.create(
-            model=self.model,
-            messages=ai21_messages,
-            system=system or "",
-            num_results=self.num_results,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            min_tokens=self.min_tokens,
-            top_p=self.top_p,
-            top_k_return=self.top_k_return,
-            stop_sequences=stop,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            **kwargs,
-        )
+        response = self.client.chat.create(**params)
 
         outputs = response.outputs
         message = AIMessage(content=outputs[0].text)
diff --git a/libs/partners/ai21/langchain_ai21/llms.py b/libs/partners/ai21/langchain_ai21/llms.py
index 27a8121bbe1..0cba917bd53 100644
--- a/libs/partners/ai21/langchain_ai21/llms.py
+++ b/libs/partners/ai21/langchain_ai21/llms.py
@@ -3,10 +3,11 @@ from functools import partial
 from typing import (
     Any,
     List,
+    Mapping,
     Optional,
 )
 
-from ai21.models import CompletionsResponse, Penalty
+from ai21.models import CompletionsResponse
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -47,16 +48,16 @@ class AI21LLM(BaseLLM, AI21Base):
     top_p: float = 1
     """A value controlling the diversity of the model's responses."""
 
-    top_k_returns: int = 0
+    top_k_return: int = 0
     """The number of top-scoring tokens to consider for each generation step."""
 
-    frequency_penalty: Optional[Penalty] = None
+    frequency_penalty: Optional[Any] = None
     """A penalty applied to tokens that are frequently generated."""
 
-    presence_penalty: Optional[Penalty] = None
+    presence_penalty: Optional[Any] = None
     """ A penalty applied to tokens that are already present in the prompt."""
 
-    count_penalty: Optional[Penalty] = None
+    count_penalty: Optional[Any] = None
     """A penalty applied to tokens based on their frequency
     in the generated responses."""
 
@@ -73,6 +74,51 @@ class AI21LLM(BaseLLM, AI21Base):
         """Return type of LLM."""
         return "ai21-llm"
 
+    @property
+    def _default_params(self) -> Mapping[str, Any]:
+        base_params = {
+            "model": self.model,
+            "num_results": self.num_results,
+            "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k_return": self.top_k_return,
+        }
+
+        if self.count_penalty is not None:
+            base_params["count_penalty"] = self.count_penalty.to_dict()
+
+        if self.custom_model is not None:
+            base_params["custom_model"] = self.custom_model
+
+        if self.epoch is not None:
+            base_params["epoch"] = self.epoch
+
+        if self.frequency_penalty is not None:
+            base_params["frequency_penalty"] = self.frequency_penalty.to_dict()
+
+        if self.presence_penalty is not None:
+            base_params["presence_penalty"] = self.presence_penalty.to_dict()
+
+        return base_params
+
+    def _build_params_for_request(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Mapping[str, Any]:
+        params = {}
+
+        if stop is not None:
+            if "stop" in kwargs:
+                raise ValueError("stop is defined in both stop and kwargs")
+            params["stop_sequences"] = stop
+
+        return {
+            **self._default_params,
+            **params,
+            **kwargs,
+        }
+
     def _generate(
         self,
         prompts: List[str],
@@ -83,10 +129,10 @@ class AI21LLM(BaseLLM, AI21Base):
         generations: List[List[Generation]] = []
         token_count = 0
 
+        params = self._build_params_for_request(stop=stop, **kwargs)
+
         for prompt in prompts:
-            response = self._invoke_completion(
-                prompt=prompt, model=self.model, stop_sequences=stop, **kwargs
-            )
+            response = self._invoke_completion(prompt=prompt, **params)
             generation = self._response_to_generation(response)
             generations.append(generation)
             token_count += self.client.count_tokens(prompt)
@@ -109,25 +155,11 @@ class AI21LLM(BaseLLM, AI21Base):
     def _invoke_completion(
         self,
         prompt: str,
-        model: str,
-        stop_sequences: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> CompletionsResponse:
         return self.client.completion.create(
             prompt=prompt,
-            model=model,
-            max_tokens=self.max_tokens,
-            num_results=self.num_results,
-            min_tokens=self.min_tokens,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k_return=self.top_k_returns,
-            custom_model=self.custom_model,
-            stop_sequences=stop_sequences,
-            frequency_penalty=self.frequency_penalty,
-            presence_penalty=self.presence_penalty,
-            count_penalty=self.count_penalty,
-            epoch=self.epoch,
+            **kwargs,
         )
 
     def _response_to_generation(
diff --git a/libs/partners/ai21/poetry.lock b/libs/partners/ai21/poetry.lock
index bb20fb898dd..c6bd28c0898 100644
--- a/libs/partners/ai21/poetry.lock
+++ b/libs/partners/ai21/poetry.lock
@@ -1,20 +1,21 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "ai21"
-version = "2.0.5"
+version = "2.1.2"
 description = ""
 optional = false
 python-versions = ">=3.8,<4.0"
 files = [
-    {file = "ai21-2.0.5-py3-none-any.whl", hash = "sha256:b10b66d9a8a3a7a010e3f67d39049913ea38d23034a34d18940a1801412bae10"},
-    {file = "ai21-2.0.5.tar.gz", hash = "sha256:3b0c933b7b7268f9d7615f6cab23febc7d78da887eaef002cdc24cd1e9971851"},
+    {file = "ai21-2.1.2-py3-none-any.whl", hash = "sha256:5ca1b5e1f11dc52fd3e894edb288634572ee6f13d8fe456c66a1825812067548"},
+    {file = "ai21-2.1.2.tar.gz", hash = "sha256:8968a2b4a98fdc5b1bca4a9c856a903fa874d0762a2570741efa071b65a1accd"},
 ]
 
 [package.dependencies]
 ai21-tokenizer = ">=0.3.9,<0.4.0"
 dataclasses-json = ">=0.6.3,<0.7.0"
 requests = ">=2.31.0,<3.0.0"
+typing-extensions = ">=4.9.0,<5.0.0"
 
 [package.extras]
 aws = ["boto3 (>=1.28.82,<2.0.0)"]
@@ -299,7 +300,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.27"
+version = "0.1.30"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -725,7 +726,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1009,4 +1009,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "f8512fa6e745dc32b49132bd7327ec204822c861c113b57db21f1f8cf2fac881"
+content-hash = "3073522be06765f2acb7efea6ed1fcc49eaa05e82534d96fa914899dbbbb541f"
diff --git a/libs/partners/ai21/pyproject.toml b/libs/partners/ai21/pyproject.toml
index c448bc532a8..31df1cd978d 100644
--- a/libs/partners/ai21/pyproject.toml
+++ b/libs/partners/ai21/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ai21"
-version = "0.1.1"
+version = "0.1.2"
 description = "An integration package connecting AI21 and LangChain"
 authors = []
 readme = "README.md"
@@ -8,7 +8,7 @@ readme = "README.md"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.22"
-ai21 = "2.0.5"
+ai21 = "^2.1.2"
 
 [tool.poetry.group.test]
 optional = true
diff --git a/libs/partners/ai21/tests/unit_tests/conftest.py b/libs/partners/ai21/tests/unit_tests/conftest.py
index ba5ee070d55..858417a7da7 100644
--- a/libs/partners/ai21/tests/unit_tests/conftest.py
+++ b/libs/partners/ai21/tests/unit_tests/conftest.py
@@ -1,3 +1,4 @@
+import os
 from contextlib import contextmanager
 from typing import Generator
 from unittest.mock import Mock
@@ -31,11 +32,30 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
     "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
     "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
     "count_penalty": Penalty(
-        scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
     ),
 }
 
 
+BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
+    "num_results": 3,
+    "max_tokens": 20,
+    "min_tokens": 10,
+    "temperature": 0.5,
+    "top_p": 0.5,
+    "top_k_return": 0,
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
+    "count_penalty": Penalty(
+        scale=0.2,
+        apply_to_punctuation=True,
+        apply_to_emojis=True,
+    ).to_dict(),
+}
+
+
 @pytest.fixture
 def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
@@ -86,10 +106,12 @@ def temporarily_unset_api_key() -> Generator:
     """
     api_key = AI21EnvConfig.api_key
     AI21EnvConfig.api_key = None
+    os.environ.pop("AI21_API_KEY", None)
     yield
 
     if api_key is not None:
         AI21EnvConfig.api_key = api_key
+        os.environ["AI21_API_KEY"] = api_key
 
 
 @pytest.fixture
diff --git a/libs/partners/ai21/tests/unit_tests/test_chat_models.py b/libs/partners/ai21/tests/unit_tests/test_chat_models.py
index 83eb06bc457..f95c73db904 100644
--- a/libs/partners/ai21/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/ai21/tests/unit_tests/test_chat_models.py
@@ -22,6 +22,7 @@ from langchain_ai21.chat_models import (
 )
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -46,7 +47,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     min_tokens = 20
     temperature = 0.1
     top_p = 0.1
-    top_k_returns = 0
+    top_k_return = 0
     frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
     presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
     count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
@@ -59,7 +60,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
         min_tokens=min_tokens,
         temperature=temperature,
         top_p=top_p,
-        top_k_returns=top_k_returns,
+        top_k_return=top_k_return,
         frequency_penalty=frequency_penalty,
         presence_penalty=presence_penalty,
         count_penalty=count_penalty,
@@ -70,7 +71,7 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     assert llm.min_tokens == min_tokens
     assert llm.temperature == temperature
     assert llm.top_p == top_p
-    assert llm.top_k_return == top_k_returns
+    assert llm.top_k_return == top_k_return
     assert llm.frequency_penalty == frequency_penalty
     assert llm.presence_penalty == presence_penalty
     assert count_penalty == count_penalty
@@ -180,14 +181,14 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
         client=mock_client_with_chat,
         **BASIC_EXAMPLE_LLM_PARAMETERS,
     )
-    llm.invoke(input=chat_input, config=dict(tags=["foo"]))
+    llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
 
     mock_client_with_chat.chat.create.assert_called_once_with(
         model="j2-ultra",
         messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
         system="",
-        stop_sequences=None,
-        **BASIC_EXAMPLE_LLM_PARAMETERS,
+        stop_sequences=["\n"],
+        **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     )
 
 
@@ -223,8 +224,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                 ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
             ],
             system="",
-            stop_sequences=None,
-            **BASIC_EXAMPLE_LLM_PARAMETERS,
+            **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
         ),
         call(
             model="j2-ultra",
@@ -232,8 +232,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
                 ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
            ],
             system="system message",
-            stop_sequences=None,
-            **BASIC_EXAMPLE_LLM_PARAMETERS,
+            **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
         ),
     ]
 )
diff --git a/libs/partners/ai21/tests/unit_tests/test_llms.py b/libs/partners/ai21/tests/unit_tests/test_llms.py
index a82240bea5d..2c47ec234ac 100644
--- a/libs/partners/ai21/tests/unit_tests/test_llms.py
+++ b/libs/partners/ai21/tests/unit_tests/test_llms.py
@@ -10,6 +10,7 @@ from ai21.models import (
 from langchain_ai21 import AI21LLM
 from tests.unit_tests.conftest import (
     BASIC_EXAMPLE_LLM_PARAMETERS,
+    BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
     DUMMY_API_KEY,
     temporarily_unset_api_key,
 )
@@ -42,7 +43,7 @@ def test_initialization__when_custom_parameters_to_init() -> None:
         min_tokens=10,
         temperature=0.5,
         top_p=0.5,
-        top_k_returns=0,
+        top_k_return=0,
         stop_sequences=["\n"],
         frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
         presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
@@ -93,7 +94,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
             custom_model=custom_model,
             stop_sequences=stop,
             epoch=epoch,
-            **BASIC_EXAMPLE_LLM_PARAMETERS,
+            **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
         ),
         call(
             prompt=prompt1,
@@ -101,7 +102,7 @@ def test_generate(mock_client_with_completion: Mock) -> None:
             custom_model=custom_model,
             stop_sequences=stop,
             epoch=epoch,
-            **BASIC_EXAMPLE_LLM_PARAMETERS,
+            **BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
         ),
     ]
 )
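
Usage sketch (not part of the diff above): after this change, every request is assembled by `_default_params` and `_build_params_for_request`, so `Penalty` objects are serialized with `.to_dict()` and a `stop` argument passed at call time is forwarded to the client as `stop_sequences`. The snippet below is only an illustration of that behaviour; the model name and parameter values mirror the unit tests, and the presence of an `AI21_API_KEY` environment variable (or an explicit `api_key`) is an assumption, not something this change set configures.

from ai21.models import Penalty
from langchain_ai21.chat_models import ChatAI21

# Assumed setup: an AI21_API_KEY environment variable is available for AI21Base;
# "j2-ultra" and the parameter values follow the package's unit tests.
chat = ChatAI21(
    model="j2-ultra",
    max_tokens=20,
    temperature=0.5,
    # Penalty instances are converted via .to_dict() inside _default_params
    # before client.chat.create() is called.
    frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
)

# `stop` is merged into the request as `stop_sequences` by _build_params_for_request;
# supplying `stop` both here and as a keyword argument raises a ValueError.
message = chat.invoke("Hello, AI21!", stop=["\n"])
print(message.content)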