AI21: tool calling support in LangChain (#25635)

This pull request introduces support for the AI21 tool calling feature,
available in the Jamba-1.5 models. Tools are supplied through the `tools`
parameter passed to the model, with each tool defined as follows:

```python
class ToolDefinition(TypedDict, total=False):
    type: Required[Literal["function"]]
    function: Required[FunctionToolDefinition]

class FunctionToolDefinition(TypedDict, total=False):
    name: Required[str]
    description: str
    parameters: ToolParameters

class ToolParameters(TypedDict, total=False):
    type: Literal["object"]
    properties: Required[Dict[str, Any]]
    required: List[str]
```
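For illustration, a `tools` entry conforming to these TypedDicts for a simple weather lookup might look like this (a hypothetical example matching the familiar OpenAI-style function schema, not taken from this PR):

```python
# Hypothetical tool definition matching the ToolDefinition schema above;
# "parameters" follows the standard JSON Schema object convention.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Return the weather for a location on a given date.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "date": {"type": "string"},
            },
            "required": ["location", "date"],
        },
    },
}
```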

When Jamba-1.5 determines that one of the provided tools should be invoked,
it responds with a list of tool calls structured as follows:

```python
class ToolCall(AI21BaseModel):
    id: str
    function: ToolFunction
    type: Literal["function"] = "function"

class ToolFunction(AI21BaseModel):
    name: str
    arguments: str
```
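Note that `arguments` is returned as a string rather than a parsed object, so a caller is expected to decode it (typically as JSON) before executing the tool. A minimal dispatch sketch, assuming a `tool_call` shaped like the model above and a hypothetical registry of callables:

```python
import json

def dispatch_tool_call(tool_call, tools_by_name):
    """Decode a tool call's JSON-encoded arguments and invoke the matching tool."""
    args = json.loads(tool_call.function.arguments)
    return tools_by_name[tool_call.function.name](**args)
```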

This pull request incorporates the changes needed to integrate this
functionality into the langchain-ai21 package.

---------

Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: pazshalev <111360591+pazshalev@users.noreply.github.com>
Co-authored-by: Paz Shalev <pazs@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
12 changed files with 397 additions and 119 deletions


@@ -150,4 +150,68 @@ from langchain_ai21 import AI21SemanticTextSplitter
 splitter = AI21SemanticTextSplitter()
 response = splitter.split_text("Your text")
 ```
+
+## Tool calls
+
+### Function calling
+
+AI21 models incorporate the function calling feature to support custom user functions. The models generate structured
+data that includes the function name and proposed arguments. This data empowers applications to call external APIs and
+incorporate the resulting information into subsequent model prompts, enriching responses with real-time data and
+context. Through function calling, users can access and utilize various services like transportation APIs and financial
+data providers to obtain more accurate and relevant answers. Here is an example of how to use function calling
+with AI21 models in LangChain:
+
+```python
+import os
+from getpass import getpass
+
+from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
+from langchain_core.tools import tool
+from langchain_ai21.chat_models import ChatAI21
+from langchain_core.utils.function_calling import convert_to_openai_tool
+
+os.environ["AI21_API_KEY"] = getpass()
+
+
+@tool
+def get_weather(location: str, date: str) -> str:
+    """Provide the weather for the specified location on the given date."""
+    if location == "New York" and date == "2024-12-05":
+        return "25 celsius"
+    elif location == "New York" and date == "2024-12-06":
+        return "27 celsius"
+    elif location == "London" and date == "2024-12-05":
+        return "22 celsius"
+    return "32 celsius"
+
+
+llm = ChatAI21(model="jamba-1.5-mini")
+llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
+
+chat_messages = [
+    SystemMessage(
+        content="You are a helpful assistant. You can use the provided tools "
+        "to assist with various tasks and provide accurate information"
+    )
+]
+human_messages = [
+    HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
+    HumanMessage(content="And what about 2024-12-06?"),
+    HumanMessage(content="OK, thank you."),
+    HumanMessage(content="What is the expected weather in London on December 5, 2024?"),
+]
+
+for human_message in human_messages:
+    print(f"User: {human_message.content}")
+    chat_messages.append(human_message)
+    response = llm_with_tools.invoke(chat_messages)
+    chat_messages.append(response)
+    if response.tool_calls:
+        tool_call = response.tool_calls[0]
+        if tool_call["name"] == "get_weather":
+            weather = get_weather.invoke(
+                {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
+            )
+            chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
+            llm_answer = llm_with_tools.invoke(chat_messages)
+            print(f"Assistant: {llm_answer.content}")
+    else:
+        print(f"Assistant: {response.content}")
+```


@@ -1,11 +1,20 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
+from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload
 
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatCompletionChunk, ChatMessage
+from ai21.models.chat import (
+    AssistantMessage as AI21AssistantMessage,
+)
+from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
+from ai21.models.chat import ChatMessage as AI21ChatMessage
+from ai21.models.chat import SystemMessage as AI21SystemMessage
+from ai21.models.chat import ToolCall as AI21ToolCall
+from ai21.models.chat import ToolFunction as AI21ToolFunction
+from ai21.models.chat import ToolMessage as AI21ToolMessage
+from ai21.models.chat import UserMessage as AI21UserMessage
 from ai21.stream.stream import Stream as AI21Stream
 from langchain_core.messages import (
     AIMessage,
@@ -13,11 +22,15 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    SystemMessage,
+    ToolCall,
+    ToolMessage,
 )
 from langchain_core.messages.ai import UsageMetadata
+from langchain_core.output_parsers.openai_tools import parse_tool_call
 from langchain_core.outputs import ChatGenerationChunk
 
-_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
+_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
 _SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
 _ROLE_TYPE = Union[str, RoleType]
@@ -40,20 +53,24 @@ class ChatAdapter(ABC):
         self,
         message: BaseMessage,
     ) -> _ChatMessageTypes:
-        content = cast(str, message.content)
         role = self._parse_role(message)
-        return self._chat_message(role=role, content=content)
+        return self._chat_message(role=role, message=message)
 
     def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
         role = None
 
+        if isinstance(message, SystemMessage):
+            return RoleType.SYSTEM
+
         if isinstance(message, HumanMessage):
             return RoleType.USER
 
         if isinstance(message, AIMessage):
             return RoleType.ASSISTANT
 
+        if isinstance(message, ToolMessage):
+            return RoleType.TOOL
+
         if isinstance(self, J2ChatAdapter):
             if not role:
                 raise ValueError(
@@ -68,7 +85,7 @@ class ChatAdapter(ABC):
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
+        message: BaseMessage,
     ) -> _ChatMessageTypes:
         pass
@@ -130,9 +147,9 @@ class J2ChatAdapter(ChatAdapter):
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
+        message: BaseMessage,
     ) -> J2ChatMessage:
-        return J2ChatMessage(role=RoleType(role), text=content)
+        return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))
 
     @overload
     def call(
@@ -174,12 +191,65 @@ class JambaChatCompletionsAdapter(ChatAdapter):
             ],
         }
 
+    def _convert_lc_tool_calls_to_ai21_tool_calls(
+        self, tool_calls: List[ToolCall]
+    ) -> Optional[List[AI21ToolCall]]:
+        """
+        Convert LangChain ToolCalls to AI21 ToolCalls.
+        """
+        ai21_tool_calls: List[AI21ToolCall] = []
+        for lc_tool_call in tool_calls:
+            if "id" not in lc_tool_call or not lc_tool_call["id"]:
+                raise ValueError("Tool call ID is missing or empty.")
+
+            ai21_tool_call = AI21ToolCall(
+                id=lc_tool_call["id"],
+                type="function",
+                function=AI21ToolFunction(
+                    name=lc_tool_call["name"],
+                    arguments=str(lc_tool_call["args"]),
+                ),
+            )
+            ai21_tool_calls.append(ai21_tool_call)
+
+        return ai21_tool_calls
+
+    def _get_content_as_string(self, base_message: BaseMessage) -> str:
+        if isinstance(base_message.content, str):
+            return base_message.content
+        elif isinstance(base_message.content, list):
+            return "\n".join(str(item) for item in base_message.content)
+        else:
+            raise ValueError("Unsupported content type")
+
     def _chat_message(
         self,
         role: _ROLE_TYPE,
-        content: str,
-    ) -> ChatMessage:
-        return ChatMessage(
+        message: BaseMessage,
+    ) -> ChatMessageParam:
+        content = self._get_content_as_string(message)
+
+        if isinstance(message, AIMessage):
+            return AI21AssistantMessage(
+                tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
+                    message.tool_calls
+                ),
+                content=content or None,
+            )
+        if isinstance(message, ToolMessage):
+            return AI21ToolMessage(
+                tool_call_id=message.tool_call_id,
+                content=content,
+            )
+        if isinstance(message, HumanMessage):
+            return AI21UserMessage(
+                content=content,
+            )
+        if isinstance(message, SystemMessage):
+            return AI21SystemMessage(
+                content=content,
+            )
+        return AI21ChatMessage(
             role=role.value if isinstance(role, RoleType) else role,
             content=content,
         )
@@ -211,7 +281,18 @@ class JambaChatCompletionsAdapter(ChatAdapter):
         if stream:
             return self._stream_response(response)
 
-        return [AIMessage(choice.message.content) for choice in response.choices]
+        ai_messages: List[BaseMessage] = []
+        for message in response.choices:
+            if message.message.tool_calls:
+                tool_calls = [
+                    parse_tool_call(tool_call.model_dump(), return_id=True)
+                    for tool_call in message.message.tool_calls
+                ]
+                ai_messages.append(AIMessage("", tool_calls=tool_calls))
+            else:
+                ai_messages.append(AIMessage(message.message.content))
+
+        return ai_messages
 
     def _stream_response(
         self,
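To make the adapter changes above concrete: the Jamba adapter now maps each LangChain message class to its AI21 counterpart, including assistant turns that carry tool calls and the tool results that answer them. A rough sketch of the message shapes involved, using hypothetical values:

```python
from langchain_core.messages import AIMessage, ToolMessage

# An assistant turn that requested a tool call (id/name/args are made up here).
ai_turn = AIMessage(
    content="",
    tool_calls=[
        {"id": "call_123", "name": "get_weather",
         "args": {"location": "London", "date": "2024-12-05"}}
    ],
)
# The tool's result, correlated back to the request by tool_call_id.
tool_turn = ToolMessage(content="22 celsius", tool_call_id="call_123")

# The adapter converts these to AI21AssistantMessage (with AI21ToolCall
# entries) and AI21ToolMessage respectively before calling the API.
```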

@@ -1,11 +1,23 @@
 import asyncio
 from functools import partial
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     LangSmithParams,
@@ -16,6 +28,9 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator
+from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_ai21.ai21_base import AI21Base
 from langchain_ai21.chat.chat_adapter import ChatAdapter
@@ -48,14 +63,14 @@ class ChatAI21(BaseChatModel, AI21Base):
     stop: Optional[List[str]] = None
     """Default stop sequences."""
 
-    max_tokens: int = 16
+    max_tokens: int = 512
     """The maximum number of tokens to generate for each response."""
 
     min_tokens: int = 0
     """The minimum number of tokens to generate for each response.
     _Not supported for all models._"""
 
-    temperature: float = 0.7
+    temperature: float = 0.4
     """A value controlling the "creativity" of the model's responses."""
 
     top_p: float = 1
@@ -246,3 +261,11 @@ class ChatAI21(BaseChatModel, AI21Base):
         )
 
         return message.content
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
+        return super().bind(tools=formatted_tools, **kwargs)
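The new `bind_tools` follows the pattern used by other LangChain chat models: each tool is converted to the OpenAI tool schema and bound onto the runnable, so callers do not have to apply `convert_to_openai_tool` themselves. A minimal usage sketch, assuming a valid `AI21_API_KEY` in the environment:

```python
from langchain_core.tools import tool
from langchain_ai21.chat_models import ChatAI21

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

llm = ChatAI21(model="jamba-1.5-mini")
# bind_tools also accepts dicts, types, and plain callables.
llm_with_tools = llm.bind_tools([add])

response = llm_with_tools.invoke("What is 2 + 3? Use the add tool.")
print(response.tool_calls)
```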


@@ -1,35 +1,36 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "ai21"
-version = "2.7.0"
+version = "2.14.1"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
-    {file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
+    {file = "ai21-2.14.1-py3-none-any.whl", hash = "sha256:618c0b5c025123c703645258472330a07ae2de17020438f6a33b29668275995c"},
+    {file = "ai21-2.14.1.tar.gz", hash = "sha256:05d9626b82206e0a5be43d17c39d4e0c7b51c2b7634d2ea38a2c698ac3e2fd5b"},
 ]
 
 [package.dependencies]
-ai21-tokenizer = ">=0.11.0,<1.0.0"
-dataclasses-json = ">=0.6.3,<0.7.0"
+ai21-tokenizer = ">=0.12.0,<1.0.0"
 httpx = ">=0.27.0,<0.28.0"
+pydantic = ">=1.9.0,<3.0.0"
 tenacity = ">=8.3.0,<9.0.0"
 typing-extensions = ">=4.9.0,<5.0.0"
 
 [package.extras]
 aws = ["boto3 (>=1.28.82,<2.0.0)"]
+vertex = ["google-auth (>=2.31.0,<3.0.0)"]
 
 [[package]]
 name = "ai21-tokenizer"
-version = "0.11.2"
+version = "0.12.0"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
-    {file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
+    {file = "ai21_tokenizer-0.12.0-py3-none-any.whl", hash = "sha256:7fd37b9093894b30b0f200e5f44fc8fb8772e2b272ef71b6d73722b4696e63c4"},
+    {file = "ai21_tokenizer-0.12.0.tar.gz", hash = "sha256:d2a5b17789d21572504b7693148bf66e692bdb3ab563023dbcbee340bcbd11c6"},
 ]
 
 [package.dependencies]
@@ -211,21 +212,6 @@
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
 
-[[package]]
-name = "dataclasses-json"
-version = "0.6.6"
-description = "Easily serialize dataclasses to and from JSON."
-optional = false
-python-versions = "<4.0,>=3.7"
-files = [
-    {file = "dataclasses_json-0.6.6-py3-none-any.whl", hash = "sha256:e54c5c87497741ad454070ba0ed411523d46beb5da102e221efb873801b0ba85"},
-    {file = "dataclasses_json-0.6.6.tar.gz", hash = "sha256:0c09827d26fffda27f1be2fed7a7a01a29c5ddcd2eb6393ad5ebf9d77e9deae8"},
-]
-
-[package.dependencies]
-marshmallow = ">=3.18.0,<4.0.0"
-typing-inspect = ">=0.4.0,<1"
-
 [[package]]
 name = "exceptiongroup"
 version = "1.2.1"
@@ -448,7 +434,7 @@
 [[package]]
 name = "langchain-core"
-version = "0.2.11"
+version = "0.2.33"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -465,6 +451,7 @@ pydantic = [
 ]
 PyYAML = ">=5.3"
 tenacity = "^8.1.0,!=8.4.0"
+typing-extensions = ">=4.7"
 
 [package.source]
 type = "directory"
@@ -490,7 +477,7 @@ url = "../../standard-tests"
 [[package]]
 name = "langchain-text-splitters"
-version = "0.2.2"
+version = "0.2.3"
 description = "LangChain text splitting utilities"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -523,25 +510,6 @@
 ]
 requests = ">=2,<3"
 
-[[package]]
-name = "marshmallow"
-version = "3.21.2"
-description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "marshmallow-3.21.2-py3-none-any.whl", hash = "sha256:70b54a6282f4704d12c0a41599682c5c5450e843b9ec406308653b47c59648a1"},
-    {file = "marshmallow-3.21.2.tar.gz", hash = "sha256:82408deadd8b33d56338d2182d455db632c6313aa2af61916672146bb32edc56"},
-]
-
-[package.dependencies]
-packaging = ">=17.0"
-
-[package.extras]
-dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"]
-docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
-tests = ["pytest", "pytz", "simplejson"]
-
 [[package]]
 name = "mypy"
 version = "1.10.1"
@@ -1257,21 +1225,6 @@
     {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"},
 ]
 
-[[package]]
-name = "typing-inspect"
-version = "0.9.0"
-description = "Runtime inspection utilities for typing module."
-optional = false
-python-versions = "*"
-files = [
-    {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
-    {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
-]
-
-[package.dependencies]
-mypy-extensions = ">=0.3.0"
-typing-extensions = ">=3.7.4"
-
 [[package]]
 name = "urllib3"
 version = "2.2.1"
@@ -1336,4 +1289,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "9dee8f52fd10c8ffe640c33620d7a33daa989ba38f56f7a33a2e98a734457015"
+content-hash = "32e1777e151eef2eb3775c4f1707ec4e80241a08fd4beed2d60a36f58d68dfea"


@@ -22,7 +22,7 @@ disallow_untyped_defs = "True"
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.2.4"
 langchain-text-splitters = "^0.2.0"
-ai21 = "^2.7.0"
+ai21 = "^2.14.1"
 
 [tool.ruff.lint]
 select = [ "E", "F", "I",]


@@ -1,12 +1,25 @@
 """Test ChatAI21 chat model."""
 
 import pytest
-from langchain_core.messages import AIMessageChunk, HumanMessage
+from langchain_core.messages import (
+    AIMessageChunk,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration
 from langchain_core.rate_limiters import InMemoryRateLimiter
+from langchain_core.tools import tool
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 from langchain_ai21.chat_models import ChatAI21
-from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
+from tests.unit_tests.conftest import (
+    J2_CHAT_MODEL_NAME,
+    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
+    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
+    JAMBA_CHAT_MODEL_NAME,
+    JAMBA_FAMILY_MODEL_NAMES,
+)
 
 rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
@@ -15,11 +28,15 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
     ids=[
         "when_j2_model",
         "when_jamba_model",
+        "when_jamba1.5-mini_model",
+        "when_jamba1.5-large_model",
     ],
     argnames=["model"],
     argvalues=[
         (J2_CHAT_MODEL_NAME,),
         (JAMBA_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
     ],
 )
 def test_invoke(model: str) -> None:
@@ -36,6 +53,10 @@ def test_invoke(model: str) -> None:
         "when_j2_model_num_results_is_3",
         "when_jamba_model_n_is_1",
         "when_jamba_model_n_is_3",
+        "when_jamba1.5_mini_model_n_is_1",
+        "when_jamba1.5_mini_model_n_is_3",
+        "when_jamba1.5_large_model_n_is_1",
+        "when_jamba1.5_large_model_n_is_3",
     ],
     argnames=["model", "num_results"],
     argvalues=[
@@ -43,12 +64,16 @@ def test_invoke(model: str) -> None:
         (J2_CHAT_MODEL_NAME, 3),
         (JAMBA_CHAT_MODEL_NAME, 1),
         (JAMBA_CHAT_MODEL_NAME, 3),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
     ],
 )
 def test_generation(model: str, num_results: int) -> None:
     """Test generation with multiple models and different result counts."""
     # Determine the configuration key based on the model type
-    config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
+    config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"
 
     # Create the model instance using the appropriate key for the result count
     llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
@@ -69,11 +94,15 @@ def test_generation(model: str, num_results: int) -> None:
     ids=[
         "when_j2_model",
         "when_jamba_model",
+        "when_jamba1.5_mini_model",
+        "when_jamba1.5_large_model",
     ],
     argnames=["model"],
     argvalues=[
         (J2_CHAT_MODEL_NAME,),
         (JAMBA_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
     ],
 )
 async def test_ageneration(model: str) -> None:
@@ -92,7 +121,7 @@ async def test_ageneration(model: str) -> None:
 def test__chat_stream() -> None:
-    llm = ChatAI21(model="jamba-instruct")  # type: ignore[call-arg]
+    llm = ChatAI21(model="jamba-1.5-mini")  # type: ignore[call-arg]
     message = HumanMessage(content="What is the meaning of life?")
 
     for chunk in llm.stream([message]):
@@ -107,3 +136,53 @@ def test__j2_chat_stream__should_raise_error() -> None:
     with pytest.raises(NotImplementedError):
         for _ in llm.stream([message]):
             pass
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "when_jamba1.5_mini_model",
+        "when_jamba1.5_large_model",
+    ],
+    argnames=["model"],
+    argvalues=[
+        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
+        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
+    ],
+)
+def test_tool_calls(model: str) -> None:
+    @tool
+    def get_weather(location: str, date: str) -> str:
+        """Provide the weather for the specified location on the given date."""
+        if location == "New York" and date == "2024-12-05":
+            return "25 celsius"
+        return "32 celsius"
+
+    llm = ChatAI21(model=model, temperature=0)  # type: ignore[call-arg]
+    llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
+
+    chat_messages = [
+        SystemMessage(
+            content="You are a helpful assistant. "
+            "You can use the provided tools "
+            "to assist with various tasks and provide "
+            "accurate information"
+        ),
+        HumanMessage(
+            content="What is the forecast for the weather "
+            "in New York on December 5, 2024?"
+        ),
+    ]
+
+    response = llm_with_tools.invoke(chat_messages)
+    chat_messages.append(response)
+
+    assert response.tool_calls is not None  # type: ignore[attr-defined]
+    tool_call = response.tool_calls[0]  # type: ignore[attr-defined]
+    assert tool_call["name"] == "get_weather"
+
+    weather = get_weather.invoke(  # type: ignore[attr-defined]
+        {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
+    )
+    chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
+    llm_answer = llm_with_tools.invoke(chat_messages)
+    content = llm_answer.content.lower()  # type: ignore[union-attr]
+    assert "new york" in content and "25" in content and "celsius" in content


@@ -1,7 +1,7 @@
 """Standard LangChain interface tests"""
 
 import time
-from typing import Type
+from typing import Optional, Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -28,6 +28,8 @@ class BaseTestAI21(ChatModelIntegrationTests):
 class TestAI21J2(BaseTestAI21):
+    has_tool_calling = False
+
     @property
     def chat_model_params(self) -> dict:
         return {
@@ -49,8 +51,23 @@ class TestAI21J2(BaseTestAI21):
 class TestAI21Jamba(BaseTestAI21):
+    has_tool_calling = False
+
     @property
     def chat_model_params(self) -> dict:
         return {
             "model": "jamba-instruct-preview",
         }
+
+
+class TestAI21Jamba1_5(BaseTestAI21):
+    @property
+    def tool_choice_value(self) -> Optional[str]:
+        """Value to use for tool choice when used in tests."""
+        return "any"
+
+    @property
+    def chat_model_params(self) -> dict:
+        return {
+            "model": "jamba-1.5-mini",
+        }


@@ -3,12 +3,21 @@ from typing import List
 import pytest
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatMessage
+from ai21.models.chat import (
+    AssistantMessage,
+    ChatMessage,
+    UserMessage,
+)
+from ai21.models.chat import (
+    SystemMessage as AI21SystemMessage,
+)
+from ai21.models.chat import ToolMessage as AI21ToolMessage
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
     HumanMessage,
     SystemMessage,
+    ToolMessage,
 )
 from langchain_core.messages import (
     ChatMessage as LangChainChatMessage,
@@ -18,6 +27,8 @@ from langchain_ai21.chat.chat_adapter import ChatAdapter
 _J2_MODEL_NAME = "j2-ultra"
 _JAMBA_MODEL_NAME = "jamba-instruct-preview"
+_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
+_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"
 
 
 @pytest.mark.parametrize(
@@ -42,12 +53,14 @@ _JAMBA_MODEL_NAME = "jamba-instruct-preview"
         (
             _JAMBA_MODEL_NAME,
             HumanMessage(content="Human Message Content"),
-            ChatMessage(role=RoleType.USER, content="Human Message Content"),
+            UserMessage(role="user", content="Human Message Content"),
         ),
         (
             _JAMBA_MODEL_NAME,
             AIMessage(content="AI Message Content"),
-            ChatMessage(role=RoleType.ASSISTANT, content="AI Message Content"),
+            AssistantMessage(
+                role="assistant", content="AI Message Content", tool_calls=[]
+            ),
         ),
     ],
 )
@@ -69,7 +82,7 @@ def test_convert_message_to_ai21_message(
     argvalues=[
         (
             _J2_MODEL_NAME,
-            SystemMessage(content="System Message Content"),
+            AI21SystemMessage(content="System Message Content"),
         ),
         (
             _J2_MODEL_NAME,
@@ -95,6 +108,8 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
         "when_first_message_is_system__should_return_system_j2_model",
         "when_all_messages_are_human_messages__should_return_system_none_jamba_model",
         "when_first_message_is_system__should_return_system_jamba_model",
+        "when_tool_calling_message__should_return_tool_jamba_mini_model",
+        "when_tool_calling_message__should_return_tool_jamba_large_model",
     ],
     argnames=["model", "messages", "expected_messages"],
     argvalues=[
@@ -142,12 +157,12 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
             ],
             {
                 "messages": [
-                    ChatMessage(
-                        role=RoleType.USER,
+                    UserMessage(
+                        role="user",
                         content="Human Message Content 1",
                     ),
-                    ChatMessage(
-                        role=RoleType.USER,
+                    UserMessage(
+                        role="user",
                         content="Human Message Content 2",
                     ),
                 ]
@@ -161,8 +176,46 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
             ],
             {
                 "messages": [
-                    ChatMessage(role="system", content="System Message Content 1"),
-                    ChatMessage(role="user", content="Human Message Content 1"),
+                    AI21SystemMessage(
+                        role="system", content="System Message Content 1"
+                    ),
+                    UserMessage(role="user", content="Human Message Content 1"),
                 ],
             },
         ),
+        (
+            _JAMBA_1_5_MINI_MODEL_NAME,
+            [
+                ToolMessage(
+                    content="42",
+                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                )
+            ],
+            {
+                "messages": [
+                    AI21ToolMessage(
+                        role="tool",
+                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                        content="42",
+                    ),
+                ],
+            },
+        ),
+        (
+            _JAMBA_1_5_LARGE_MODEL_NAME,
+            [
+                ToolMessage(
+                    content="42",
+                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                )
+            ],
+            {
+                "messages": [
+                    AI21ToolMessage(
+                        role="tool",
+                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
+                        content="42",
+                    ),
+                ],
+            },
+        ),


@@ -23,8 +23,16 @@ from pytest_mock import MockerFixture
 J2_CHAT_MODEL_NAME = "j2-ultra"
 JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
+JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
+JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
 DUMMY_API_KEY = "test_api_key"
 
+JAMBA_FAMILY_MODEL_NAMES = [
+    JAMBA_CHAT_MODEL_NAME,
+    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
+    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
+]
+
 BASIC_EXAMPLE_LLM_PARAMETERS = {
     "num_results": 3,
     "max_tokens": 20,
@@ -32,9 +40,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -48,9 +56,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -59,7 +67,7 @@
 }
 
 SEGMENTS = [
-    Segment(
+    Segment(  # type: ignore[call-arg]
         segment_type="normal_text",
         segment_text=(
             "The original full name of the franchise is Pocket Monsters "
@@ -70,7 +78,7 @@
             "in pronunciation."
         ),
     ),
-    Segment(
+    Segment(  # type: ignore[call-arg]
         segment_type="normal_text",
         segment_text=(
             "Pokémon refers to both the franchise itself and the creatures "
@@ -92,9 +100,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -108,9 +116,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
     "temperature": 0.5,
     "top_p": 0.5,
     "top_k_return": 0,
-    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
-    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
-    "count_penalty": Penalty(
+    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
+    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
+    "count_penalty": Penalty(  # type: ignore[call-arg]
         scale=0.2,
         apply_to_punctuation=True,
         apply_to_emojis=True,
@@ -124,7 +132,7 @@ def mocked_completion_response(mocker: MockerFixture) -> Mock:
     mocked_response = mocker.MagicMock(spec=CompletionsResponse)
     mocked_response.prompt = "this is a test prompt"
     mocked_response.completions = [
-        Completion(
+        Completion(  # type: ignore[call-arg]
             data=CompletionData(text="test", tokens=[]),
             finish_reason=CompletionFinishReason(reason=None, length=None),
         )
@@ -152,7 +160,7 @@ def mock_client_with_chat(mocker: MockerFixture) -> Mock:
     mock_client = mocker.MagicMock(spec=AI21Client)
     mock_client.chat = mocker.MagicMock()
 
-    output = ChatOutput(
+    output = ChatOutput(  # type: ignore[call-arg]
         text="Hello Pickle Rick!",
         role=RoleType.ASSISTANT,
         finish_reason=FinishReason(reason="testing"),
@@ -178,7 +186,7 @@ def temporarily_unset_api_key() -> Generator:
 def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
     mock_client = mocker.MagicMock(spec=AI21Client)
     mock_client.answer = mocker.MagicMock()
-    mock_client.answer.create.return_value = AnswerResponse(
+    mock_client.answer.create.return_value = AnswerResponse(  # type: ignore[call-arg]
         id="some_id",
         answer="some answer",
         answer_in_context=False,


@@ -45,9 +45,9 @@ def test_initialization__when_custom_parameters_in_init() -> None:
     temperature = 0.1
     top_p = 0.1
     top_k_return = 0
-    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
-    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
-    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
+    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)  # type: ignore[call-arg]
+    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)  # type: ignore[call-arg]
+    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)  # type: ignore[call-arg]
 
     llm = ChatAI21(  # type: ignore[call-arg]
         api_key=DUMMY_API_KEY,  # type: ignore[arg-type]


@@ -17,9 +17,9 @@ _EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
 _EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
     results=[
-        EmbedResult(_EXAMPLE_EMBEDDING_0),
-        EmbedResult(_EXAMPLE_EMBEDDING_1),
-        EmbedResult(_EXAMPLE_EMBEDDING_2),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
+        EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
     ],
     id="test_id",
 )


@@ -49,9 +49,9 @@ def test_initialization__when_custom_parameters_to_init() -> None:
         top_p=0.5,
         top_k_return=0,
         stop_sequences=["\n"],
-        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
-        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
-        count_penalty=Penalty(
+        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
+        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
+        count_penalty=Penalty(  # type: ignore[call-arg]
            scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
        ),
        custom_model="test_model",