diff --git a/libs/partners/ai21/README.md b/libs/partners/ai21/README.md
index df66bc72111..7cb09fd9334 100644
--- a/libs/partners/ai21/README.md
+++ b/libs/partners/ai21/README.md
@@ -150,4 +150,68 @@ from langchain_ai21 import AI21SemanticTextSplitter
 splitter = AI21SemanticTextSplitter()
 response = splitter.split_text("Your text")
+```
+
+## Tool calls
+
+### Function calling
+
+AI21 models support function calling, which lets you describe custom tools to the model. When a tool is
+relevant, the model returns structured output containing the function name and proposed arguments. Your
+application can use that output to call external APIs (for example, transportation or financial-data
+services) and feed the results back into the conversation, enriching responses with real-time data and
+context. Here is an example of how to use function calling with AI21 models in LangChain:
+
+```python
+import os
+from getpass import getpass
+
+from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
+from langchain_core.tools import tool
+from langchain_ai21.chat_models import ChatAI21
+from langchain_core.utils.function_calling import convert_to_openai_tool
+
+os.environ["AI21_API_KEY"] = getpass()
+
+
+@tool
+def get_weather(location: str, date: str) -> str:
+    """Provide the weather for the specified location on the given date."""
+    if location == "New York" and date == "2024-12-05":
+        return "25 celsius"
+    elif location == "New York" and date == "2024-12-06":
+        return "27 celsius"
+    elif location == "London" and date == "2024-12-05":
+        return "22 celsius"
+    return "32 celsius"
+
+
+llm = ChatAI21(model="jamba-1.5-mini")
+
+llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
+
+chat_messages = [
+    SystemMessage(
+        content="You are a helpful assistant. You can use the provided tools "
+        "to assist with various tasks and provide accurate information"
+    )
+]
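+
+# Each user turn below is appended to the running `chat_messages` history;
+# the weather questions should come back as `get_weather` tool calls, which
+# the loop executes and feeds back to the model as ToolMessages.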
+
+human_messages = [
+    HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
+    HumanMessage(content="And what about the 2024-12-06?"),
+    HumanMessage(content="OK, thank you."),
+    HumanMessage(content="What is the expected weather in London on December 5, 2024?"),
+]
+
+
+for human_message in human_messages:
+    print(f"User: {human_message.content}")
+    chat_messages.append(human_message)
+    response = llm_with_tools.invoke(chat_messages)
+    chat_messages.append(response)
+    if response.tool_calls:
+        tool_call = response.tool_calls[0]
+        if tool_call["name"] == "get_weather":
+            weather = get_weather.invoke(
+                {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
+            )
+            chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
+            llm_answer = llm_with_tools.invoke(chat_messages)
+            print(f"Assistant: {llm_answer.content}")
+            chat_messages.append(llm_answer)
+    else:
+        print(f"Assistant: {response.content}")
+```
\ No newline at end of file
diff --git a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
index 699805f01ea..793ae241674 100644
--- a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
+++ b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
@@ -1,11 +1,20 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
+from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload
 
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatCompletionChunk, ChatMessage
+from ai21.models.chat import (
+    AssistantMessage as AI21AssistantMessage,
+)
+from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
+from ai21.models.chat import ChatMessage as AI21ChatMessage
+from ai21.models.chat import SystemMessage as AI21SystemMessage
+from ai21.models.chat import ToolCall as AI21ToolCall
+from ai21.models.chat import ToolFunction as AI21ToolFunction
+from ai21.models.chat import ToolMessage as AI21ToolMessage
+from ai21.models.chat import UserMessage as AI21UserMessage
 from ai21.stream.stream import Stream as AI21Stream
 from langchain_core.messages import (
     AIMessage,
@@ -13,11 +22,15 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    SystemMessage,
+    ToolCall,
+    ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.output_parsers.openai_tools import parse_tool_call
 from langchain_core.outputs import ChatGenerationChunk
 
-_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
+_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
 _SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
 _ROLE_TYPE = Union[str, RoleType]
 
@@ -40,20 +53,24 @@ class ChatAdapter(ABC):
         self,
         message: BaseMessage,
     ) -> _ChatMessageTypes:
-        content = cast(str, message.content)
         role = self._parse_role(message)
-
-        return self._chat_message(role=role, content=content)
+        return self._chat_message(role=role, message=message)
 
     def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
         role = None
 
+        if isinstance(message, SystemMessage):
+            return RoleType.SYSTEM
+
         if isinstance(message, HumanMessage):
             return RoleType.USER
 
         if isinstance(message, AIMessage):
             return RoleType.ASSISTANT
 
+        if isinstance(message, ToolMessage):
+            return RoleType.TOOL
+
         if isinstance(self, J2ChatAdapter):
             if not role:
                 raise ValueError(
@@ -68,7 +85,7 @@ class ChatAdapter(ABC):
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
+        message: BaseMessage,
     ) -> _ChatMessageTypes:
         pass
 
@@ -130,9 +147,9 @@ class J2ChatAdapter(ChatAdapter):
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
+        message: BaseMessage,
     ) -> J2ChatMessage:
-        return J2ChatMessage(role=RoleType(role), text=content)
+        return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))
 
     @overload
     def call(
@@ -174,12 +191,65 @@ class JambaChatCompletionsAdapter(ChatAdapter):
             ],
         }
 
+    def _convert_lc_tool_calls_to_ai21_tool_calls(
+        self, tool_calls: List[ToolCall]
+    ) -> Optional[List[AI21ToolCall]]:
+        """
+        Convert LangChain ToolCalls to AI21 ToolCalls.
+        """
+        ai21_tool_calls: List[AI21ToolCall] = []
+        for lc_tool_call in tool_calls:
+            if "id" not in lc_tool_call or not lc_tool_call["id"]:
+                raise ValueError("Tool call ID is missing or empty.")
+
+            ai21_tool_call = AI21ToolCall(
+                id=lc_tool_call["id"],
+                type="function",
+                function=AI21ToolFunction(
+                    name=lc_tool_call["name"],
+                    arguments=str(lc_tool_call["args"]),
+                ),
+            )
+            ai21_tool_calls.append(ai21_tool_call)
+
+        return ai21_tool_calls
+
+    def _get_content_as_string(self, base_message: BaseMessage) -> str:
+        if isinstance(base_message.content, str):
+            return base_message.content
+        elif isinstance(base_message.content, list):
+            return "\n".join(str(item) for item in base_message.content)
+        else:
+            raise ValueError("Unsupported content type")
+
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
-    ) -> ChatMessage:
-        return ChatMessage(
+        message: BaseMessage,
+    ) -> ChatMessageParam:
+        content = self._get_content_as_string(message)
+
+        if isinstance(message, AIMessage):
+            return AI21AssistantMessage(
+                tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
+                    message.tool_calls
+                ),
+                content=content or None,
+            )
+        if isinstance(message, ToolMessage):
+            return AI21ToolMessage(
+                tool_call_id=message.tool_call_id,
+                content=content,
+            )
+        if isinstance(message, HumanMessage):
+            return AI21UserMessage(
+                content=content,
+            )
+        if isinstance(message, SystemMessage):
+            return AI21SystemMessage(
+                content=content,
+            )
+        return AI21ChatMessage(
             role=role.value if isinstance(role, RoleType) else role,
             content=content,
         )
@@ -211,7 +281,18 @@ class JambaChatCompletionsAdapter(ChatAdapter):
         if stream:
             return self._stream_response(response)
 
-        return [AIMessage(choice.message.content) for choice in response.choices]
+        ai_messages: List[BaseMessage] = []
+        for message in response.choices:
+            if message.message.tool_calls:
+                tool_calls = [
+                    parse_tool_call(tool_call.model_dump(), return_id=True)
+                    for tool_call in message.message.tool_calls
+                ]
+                ai_messages.append(AIMessage("", tool_calls=tool_calls))
+            else:
+                ai_messages.append(AIMessage(message.message.content))
+
+        return ai_messages
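+
+    # NOTE: `parse_tool_call` (langchain_core) converts the OpenAI-style dict
+    # produced by `tool_call.model_dump()` into a LangChain ToolCall, and
+    # `return_id=True` keeps the provider-assigned call id so a subsequent
+    # ToolMessage can reference it.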
 
     def _stream_response(
         self,
diff --git a/libs/partners/ai21/langchain_ai21/chat_models.py b/libs/partners/ai21/langchain_ai21/chat_models.py
index 374f262e101..a1ee06c4bae 100644
--- a/libs/partners/ai21/langchain_ai21/chat_models.py
+++ b/libs/partners/ai21/langchain_ai21/chat_models.py
@@ -1,11 +1,23 @@
 import asyncio
 from functools import partial
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     LangSmithParams,
@@ -16,6 +28,9 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_ai21.ai21_base import AI21Base
 from langchain_ai21.chat.chat_adapter import ChatAdapter
@@ -48,14 +63,14 @@ class ChatAI21(BaseChatModel, AI21Base):
     stop: Optional[List[str]] = None
     """Default stop sequences."""
 
-    max_tokens: int = 16
+    max_tokens: int = 512
     """The maximum number of tokens to generate for each response."""
 
     min_tokens: int = 0
     """The minimum number of tokens to generate for each response.
     _Not supported for all models._"""
 
-    temperature: float = 0.7
+    temperature: float = 0.4
     """A value controlling the "creativity" of the model's responses."""
 
     top_p: float = 1
@@ -246,3 +261,11 @@ class ChatAI21(BaseChatModel, AI21Base):
         )
 
         return message.content
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
diff --git a/libs/partners/ai21/poetry.lock b/libs/partners/ai21/poetry.lock
index 3f7f3f5ded1..1451df90819 100644
--- a/libs/partners/ai21/poetry.lock
+++ b/libs/partners/ai21/poetry.lock
@@ -1,35 +1,36 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "ai21"
-version = "2.7.0"
+version = "2.14.1"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
-    {file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
+    {file = "ai21-2.14.1-py3-none-any.whl", hash = "sha256:618c0b5c025123c703645258472330a07ae2de17020438f6a33b29668275995c"},
+    {file = "ai21-2.14.1.tar.gz", hash = "sha256:05d9626b82206e0a5be43d17c39d4e0c7b51c2b7634d2ea38a2c698ac3e2fd5b"},
 ]
 
 [package.dependencies]
-ai21-tokenizer = ">=0.11.0,<1.0.0"
-dataclasses-json = ">=0.6.3,<0.7.0"
+ai21-tokenizer = ">=0.12.0,<1.0.0"
 httpx = ">=0.27.0,<0.28.0"
+pydantic = ">=1.9.0,<3.0.0"
 tenacity = ">=8.3.0,<9.0.0"
 typing-extensions = ">=4.9.0,<5.0.0"
 
 [package.extras]
 aws = ["boto3 (>=1.28.82,<2.0.0)"]
+vertex = ["google-auth (>=2.31.0,<3.0.0)"]
 
 [[package]]
 name = "ai21-tokenizer"
-version = "0.11.2"
+version = "0.12.0"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
-    {file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
+    {file = "ai21_tokenizer-0.12.0-py3-none-any.whl", hash = "sha256:7fd37b9093894b30b0f200e5f44fc8fb8772e2b272ef71b6d73722b4696e63c4"},
+    {file = "ai21_tokenizer-0.12.0.tar.gz", hash = "sha256:d2a5b17789d21572504b7693148bf66e692bdb3ab563023dbcbee340bcbd11c6"},
 ]
 
 [package.dependencies]
@@ -211,21 +212,6 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
 
-[[package]]
-name = "dataclasses-json"
-version = "0.6.6"
-description = "Easily serialize dataclasses to and from JSON."
-optional = false
-python-versions = "<4.0,>=3.7"
-files = [
-    {file = "dataclasses_json-0.6.6-py3-none-any.whl", hash = "sha256:e54c5c87497741ad454070ba0ed411523d46beb5da102e221efb873801b0ba85"},
-    {file = "dataclasses_json-0.6.6.tar.gz", hash = "sha256:0c09827d26fffda27f1be2fed7a7a01a29c5ddcd2eb6393ad5ebf9d77e9deae8"},
-]
-
-[package.dependencies]
-marshmallow = ">=3.18.0,<4.0.0"
-typing-inspect = ">=0.4.0,<1"
-
 [[package]]
 name = "exceptiongroup"
 version = "1.2.1"
@@ -448,7 +434,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.2.11"
+version = "0.2.33"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -465,6 +451,7 @@ pydantic = [
 ]
 PyYAML = ">=5.3"
 tenacity = "^8.1.0,!=8.4.0"
+typing-extensions = ">=4.7"
 
 [package.source]
 type = "directory"
@@ -490,7 +477,7 @@ url = "../../standard-tests"
 
 [[package]]
 name = "langchain-text-splitters"
-version = "0.2.2"
+version = "0.2.3"
 description = "LangChain text splitting utilities"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -523,25 +510,6 @@ pydantic = [
 ]
 requests = ">=2,<3"
 
-[[package]]
-name = "marshmallow"
-version = "3.21.2"
-description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "marshmallow-3.21.2-py3-none-any.whl", hash = "sha256:70b54a6282f4704d12c0a41599682c5c5450e843b9ec406308653b47c59648a1"},
-    {file = "marshmallow-3.21.2.tar.gz", hash = "sha256:82408deadd8b33d56338d2182d455db632c6313aa2af61916672146bb32edc56"},
-]
-
-[package.dependencies]
-packaging = ">=17.0"
-
-[package.extras]
-dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"]
-docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
-tests = ["pytest", "pytz", "simplejson"]
-
 [[package]]
 name = "mypy"
 version = "1.10.1"
@@ -1257,21 +1225,6 @@ files = [
     {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"},
 ]
 
-[[package]]
-name = "typing-inspect"
-version = "0.9.0"
-description = "Runtime inspection utilities for typing module."
-optional = false
-python-versions = "*"
-files = [
-    {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
-    {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
-]
-
-[package.dependencies]
-mypy-extensions = ">=0.3.0"
-typing-extensions = ">=3.7.4"
-
 [[package]]
 name = "urllib3"
 version = "2.2.1"
@@ -1336,4 +1289,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "9dee8f52fd10c8ffe640c33620d7a33daa989ba38f56f7a33a2e98a734457015"
+content-hash = "32e1777e151eef2eb3775c4f1707ec4e80241a08fd4beed2d60a36f58d68dfea"
diff --git a/libs/partners/ai21/pyproject.toml b/libs/partners/ai21/pyproject.toml
index fe322f5f2e9..9b8806bb5e1 100644
--- a/libs/partners/ai21/pyproject.toml
+++ b/libs/partners/ai21/pyproject.toml
@@ -22,7 +22,7 @@ disallow_untyped_defs = "True"
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.2.4"
 langchain-text-splitters = "^0.2.0"
-ai21 = "^2.7.0"
+ai21 = "^2.14.1"
 
 [tool.ruff.lint]
 select = [ "E", "F", "I",]
diff --git a/libs/partners/ai21/tests/integration_tests/test_chat_models.py b/libs/partners/ai21/tests/integration_tests/test_chat_models.py
index d0e6a7e09bb..8f5b72cbb09 100644
--- a/libs/partners/ai21/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/ai21/tests/integration_tests/test_chat_models.py
@@ -1,12 +1,25 @@
 """Test ChatAI21 chat model."""
 
 import pytest
-from langchain_core.messages import AIMessageChunk, HumanMessage
+from langchain_core.messages import (
+    AIMessageChunk,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration
 from langchain_core.rate_limiters import InMemoryRateLimiter
+from langchain_core.tools import tool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_ai21.chat_models import ChatAI21
-from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
+from tests.unit_tests.conftest import (
+    J2_CHAT_MODEL_NAME,
+    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
+    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
+    JAMBA_CHAT_MODEL_NAME,
+    JAMBA_FAMILY_MODEL_NAMES,
+)
 
 rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
 
@@ -15,11 +28,15 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
     ids=[
         "when_j2_model",
         "when_jamba_model",
+        "when_jamba1.5-mini_model",
+        "when_jamba1.5-large_model",
     ],
     argnames=["model"],
    argvalues=[
         (J2_CHAT_MODEL_NAME,),
         (JAMBA_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
     ],
 )
 def test_invoke(model: str) -> None:
@@ -36,6 +53,10 @@ def test_invoke(model: str) -> None:
         "when_j2_model_num_results_is_3",
         "when_jamba_model_n_is_1",
         "when_jamba_model_n_is_3",
+        "when_jamba1.5_mini_model_n_is_1",
+        "when_jamba1.5_mini_model_n_is_3",
+        "when_jamba1.5_large_model_n_is_1",
+        "when_jamba1.5_large_model_n_is_3",
     ],
     argnames=["model", "num_results"],
     argvalues=[
@@ -43,12 +64,16 @@ def test_invoke(model: str) -> None:
         (J2_CHAT_MODEL_NAME, 3),
         (JAMBA_CHAT_MODEL_NAME, 1),
         (JAMBA_CHAT_MODEL_NAME, 3),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
     ],
 )
 def test_generation(model: str, num_results: int) -> None:
     """Test generation with multiple models and different result counts."""
     # Determine the configuration key based on the model type
-    config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
+    config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"
 
     # Create the model instance using the appropriate key for the result count
     llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
@@ -69,11 +94,15 @@ def test_generation(model: str, num_results: int) -> None:
     ids=[
         "when_j2_model",
         "when_jamba_model",
+        "when_jamba1.5_mini_model",
+        "when_jamba1.5_large_model",
     ],
     argnames=["model"],
     argvalues=[
         (J2_CHAT_MODEL_NAME,),
         (JAMBA_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
     ],
 )
 async def test_ageneration(model: str) -> None:
@@ -92,7 +121,7 @@ async def test_ageneration(model: str) -> None:
 
 
 def test__chat_stream() -> None:
-    llm = ChatAI21(model="jamba-instruct")  # type: ignore[call-arg]
+    llm = ChatAI21(model="jamba-1.5-mini")  # type: ignore[call-arg]
     message = HumanMessage(content="What is the meaning of life?")
 
     for chunk in llm.stream([message]):
@@ -107,3 +136,53 @@ def test__j2_chat_stream__should_raise_error() -> None:
     with pytest.raises(NotImplementedError):
         for _ in llm.stream([message]):
             pass
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "when_jamba1.5_mini_model",
+        "when_jamba1.5_large_model",
+    ],
+    argnames=["model"],
+    argvalues=[
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
+    ],
+)
+def test_tool_calls(model: str) -> None:
+    @tool
+    def get_weather(location: str, date: str) -> str:
+        """Provide the weather for the specified location on the given date."""
+        if location == "New York" and date == "2024-12-05":
+            return "25 celsius"
+        return "32 celsius"
+
+    llm = ChatAI21(model=model, temperature=0)  # type: ignore[call-arg]
+    llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
+
+    chat_messages = [
+        SystemMessage(
+            content="You are a helpful assistant. "
+            "You can use the provided tools "
+            "to assist with various tasks and provide "
+            "accurate information"
+        ),
+        HumanMessage(
+            content="What is the forecast for the weather "
+            "in New York on December 5, 2024?"
+        ),
+    ]
+
+    response = llm_with_tools.invoke(chat_messages)
+    chat_messages.append(response)
+    assert response.tool_calls is not None  # type: ignore[attr-defined]
+    tool_call = response.tool_calls[0]  # type: ignore[attr-defined]
+    assert tool_call["name"] == "get_weather"
+
+    weather = get_weather.invoke(  # type: ignore[attr-defined]
+        {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
+    )
+    chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
+    llm_answer = llm_with_tools.invoke(chat_messages)
+    content = llm_answer.content.lower()  # type: ignore[union-attr]
+    assert "new york" in content and "25" in content and "celsius" in content
diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py
index 5896573102d..dcd39cf5b36 100644
--- a/libs/partners/ai21/tests/integration_tests/test_standard.py
+++ b/libs/partners/ai21/tests/integration_tests/test_standard.py
@@ -1,7 +1,7 @@
 """Standard LangChain interface tests"""
 
 import time
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -28,6 +28,8 @@ class BaseTestAI21(ChatModelIntegrationTests):
 
 
 class TestAI21J2(BaseTestAI21):
+    has_tool_calling = False
+
     @property
     def chat_model_params(self) -> dict:
         return {
@@ -49,8 +51,23 @@ class TestAI21J2(BaseTestAI21):
 
 
 class TestAI21Jamba(BaseTestAI21):
+    has_tool_calling = False
+
     @property
     def chat_model_params(self) -> dict:
         return {
             "model": "jamba-instruct-preview",
         }
+
+
+class TestAI21Jamba1_5(BaseTestAI21):
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return "any"
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {
+            "model": "jamba-1.5-mini",
+        }
diff --git a/libs/partners/ai21/tests/unit_tests/chat/test_chat_adapter.py b/libs/partners/ai21/tests/unit_tests/chat/test_chat_adapter.py
index b75157b8c3f..d1abe578ab0 100644
--- a/libs/partners/ai21/tests/unit_tests/chat/test_chat_adapter.py
+++ b/libs/partners/ai21/tests/unit_tests/chat/test_chat_adapter.py
@@ -3,12 +3,21 @@ from typing import List
 import pytest
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatMessage
+from ai21.models.chat import (
+    AssistantMessage,
+    ChatMessage,
+    UserMessage,
+)
+from ai21.models.chat import (
+    SystemMessage as AI21SystemMessage,
+)
+from ai21.models.chat import ToolMessage as AI21ToolMessage
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
 from langchain_core.messages import (
     ChatMessage as LangChainChatMessage,
@@ -18,6 +27,8 @@ from langchain_ai21.chat.chat_adapter import ChatAdapter
 
 _J2_MODEL_NAME = "j2-ultra"
 _JAMBA_MODEL_NAME = "jamba-instruct-preview"
+_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
+_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"
 
 
 @pytest.mark.parametrize(
@@ -42,12 +53,14 @@ _JAMBA_MODEL_NAME = "jamba-instruct-preview"
         (
             _JAMBA_MODEL_NAME,
             HumanMessage(content="Human Message Content"),
-            ChatMessage(role=RoleType.USER, content="Human Message Content"),
+            UserMessage(role="user", content="Human Message Content"),
         ),
         (
             _JAMBA_MODEL_NAME,
             AIMessage(content="AI Message Content"),
-            ChatMessage(role=RoleType.ASSISTANT, content="AI Message Content"),
+            AssistantMessage(
+                role="assistant", content="AI Message Content", tool_calls=[]
+            ),
         ),
     ],
 )
@@ -69,7 +82,7 @@ def test_convert_message_to_ai21_message(
     argvalues=[
         (
             _J2_MODEL_NAME,
-            SystemMessage(content="System Message Content"),
+            AI21SystemMessage(content="System Message Content"),
         ),
         (
             _J2_MODEL_NAME,
@@ -95,6 +108,8 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
         "when_first_message_is_system__should_return_system_j2_model",
         "when_all_messages_are_human_messages__should_return_system_none_jamba_model",
         "when_first_message_is_system__should_return_system_jamba_model",
+        "when_tool_calling_message__should_return_tool_jamba_mini_model",
+        "when_tool_calling_message__should_return_tool_jamba_large_model",
     ],
     argnames=["model", "messages", "expected_messages"],
     argvalues=[
@@ -142,12 +157,12 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
             ],
             {
                 "messages": [
-                    ChatMessage(
-                        role=RoleType.USER,
+                    UserMessage(
+                        role="user",
                         content="Human Message Content 1",
                     ),
-                    ChatMessage(
-                        role=RoleType.USER,
+                    UserMessage(
+                        role="user",
                         content="Human Message Content 2",
                     ),
                 ]
@@ -161,8 +176,46 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
             ],
             {
                 "messages": [
-                    ChatMessage(role="system", content="System Message Content 1"),
-                    ChatMessage(role="user", content="Human Message Content 1"),
+                    AI21SystemMessage(
+                        role="system", content="System Message Content 1"
+                    ),
+                    UserMessage(role="user", content="Human Message Content 1"),
                 ],
             },
         ),
+        (
+            _JAMBA_1_5_MINI_MODEL_NAME,
+            [
+                ToolMessage(
+                    content="42",
+                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                )
+            ],
+            {
+                "messages": [
+                    AI21ToolMessage(
+                        role="tool",
+                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                        content="42",
+                    ),
+                ],
+            },
+        ),
+        (
+            _JAMBA_1_5_LARGE_MODEL_NAME,
+            [
+                ToolMessage(
+                    content="42",
+                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                )
+            ],
+            {
+                "messages": [
+                    AI21ToolMessage(
+                        role="tool",
+                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                        content="42",
+                    ),
+                ],
+            },
+        ),
diff --git a/libs/partners/ai21/tests/unit_tests/conftest.py b/libs/partners/ai21/tests/unit_tests/conftest.py
index 9d95584daa4..f70b2cdaa0e 100644
--- a/libs/partners/ai21/tests/unit_tests/conftest.py
+++ b/libs/partners/ai21/tests/unit_tests/conftest.py
@@ -23,8 +23,16 @@ from pytest_mock import MockerFixture
 
 J2_CHAT_MODEL_NAME = "j2-ultra"
 JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
+JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
+JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
 DUMMY_API_KEY = "test_api_key"
 
+JAMBA_FAMILY_MODEL_NAMES = [
+    JAMBA_CHAT_MODEL_NAME,
+    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
+    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
+]
+
 BASIC_EXAMPLE_LLM_PARAMETERS = {
     "num_results": 3,
     "max_tokens": 20,
@@ -32,9 +40,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -48,9 +56,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -59,7 +67,7 @@ BASIC_EXAMPLE_CHAT_PARAMETERS = {
 }
 
 SEGMENTS = [
-    Segment(
+    Segment(  # type: ignore[call-arg]
         segment_type="normal_text",
         segment_text=(
             "The original full name of the franchise is Pocket Monsters "
@@ -70,7 +78,7 @@ SEGMENTS = [
             "in pronunciation."
         ),
     ),
-    Segment(
+    Segment(  # type: ignore[call-arg]
         segment_type="normal_text",
         segment_text=(
             "Pokémon refers to both the franchise itself and the creatures "
@@ -92,9 +100,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -108,9 +116,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -124,7 +132,7 @@ def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
     mocked_response.prompt = "this is a test prompt"
     mocked_response.completions = [
-        Completion(
+        Completion(  # type: ignore[call-arg]
             data=CompletionData(text="test", tokens=[]),
             finish_reason=CompletionFinishReason(reason=None, length=None),
         )
@@ -152,7 +160,7 @@ def mock_client_with_chat(mocker: MockerFixture) -> Mock:
     mock_client = mocker.MagicMock(spec=AI21Client)
     mock_client.chat = mocker.MagicMock()
 
-    output = ChatOutput(
+    output = ChatOutput(  # type: ignore[call-arg]
         text="Hello Pickle Rick!",
         role=RoleType.ASSISTANT,
         finish_reason=FinishReason(reason="testing"),
@@ -178,7 +186,7 @@ def temporarily_unset_api_key() -> Generator:
 def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
     mock_client = mocker.MagicMock(spec=AI21Client)
     mock_client.answer = mocker.MagicMock()
-    mock_client.answer.create.return_value = AnswerResponse(
+    mock_client.answer.create.return_value = AnswerResponse(  # type: ignore[call-arg]
         id="some_id",
         answer="some answer",
         answer_in_context=False,
diff --git a/libs/partners/ai21/tests/unit_tests/test_chat_models.py b/libs/partners/ai21/tests/unit_tests/test_chat_models.py
index f267b1ba5c1..4ad6b39a3a4 100644
--- a/libs/partners/ai21/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/ai21/tests/unit_tests/test_chat_models.py
@@ -45,9 +45,9 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     temperature = 0.1
     top_p = 0.1
     top_k_return = 0
-    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
-    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
-    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
+    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)  # type: ignore[call-arg]
+    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)  # type: ignore[call-arg]
+    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)  # type: ignore[call-arg]
 
     llm = ChatAI21(  # type: ignore[call-arg]
         api_key=DUMMY_API_KEY,  # type: ignore[arg-type]
diff --git a/libs/partners/ai21/tests/unit_tests/test_embeddings.py b/libs/partners/ai21/tests/unit_tests/test_embeddings.py
index 96e88b929f4..8c0062ad613 100644
--- a/libs/partners/ai21/tests/unit_tests/test_embeddings.py
+++ b/libs/partners/ai21/tests/unit_tests/test_embeddings.py
@@ -17,9 +17,9 @@ _EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
 
 _EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
     results=[
-        EmbedResult(_EXAMPLE_EMBEDDING_0),
-        EmbedResult(_EXAMPLE_EMBEDDING_1),
-        EmbedResult(_EXAMPLE_EMBEDDING_2),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
     ],
     id="test_id",
 )
diff --git a/libs/partners/ai21/tests/unit_tests/test_llms.py b/libs/partners/ai21/tests/unit_tests/test_llms.py
index e862c81373b..648f2ef27d4 100644
--- a/libs/partners/ai21/tests/unit_tests/test_llms.py
+++ b/libs/partners/ai21/tests/unit_tests/test_llms.py
@@ -49,9 +49,9 @@ def test_initialization__when_custom_parameters_to_init() -> None:
         top_p=0.5,
         top_k_return=0,
         stop_sequences=["\n"],
-        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
-        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
-        count_penalty=Penalty(
+        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+        count_penalty=Penalty(  # type: ignore[call-arg]
             scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
         ),
         custom_model="test_model",
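
For reference, a minimal sketch of the round-trip the adapter change relies on. `parse_tool_call` is the real `langchain_core` helper used in `chat_adapter.py` above; the `raw_tool_call` dict below is a made-up stand-in for what `tool_call.model_dump()` would produce from the ai21 SDK, and the id and arguments are illustrative only:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

# Hypothetical dump of an AI21 ToolCall, mirroring the OpenAI tool-call shape
# (id, type, and a function with JSON-encoded arguments).
raw_tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "get_weather",
        "arguments": '{"location": "New York", "date": "2024-12-05"}',
    },
}

# `return_id=True` keeps the provider-assigned id, which the integration test
# above echoes back via `ToolMessage(tool_call_id=...)`.
parsed = parse_tool_call(raw_tool_call, return_id=True)
print(parsed["name"])  # get_weather
print(parsed["args"])  # {'location': 'New York', 'date': '2024-12-05'}
print(parsed["id"])    # call_abc123
```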