partners: AI21 Labs Jamba Streaming Support (#23538)

Thank you for contributing to LangChain!

- [x] **PR title**: "partners: AI21 Labs Jamba Streaming Support"

- [x] **PR message**:
    - **Description:** Added streaming support to the AI21 Jamba chat model (a short usage sketch follows the guidelines below).
    - **Twitter handle:** https://github.com/AI21Labs


- [x] **Add tests and docs**: Streaming integration tests and README documentation updates are included.

- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.
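
A minimal usage sketch of the new streaming behavior, assuming an AI21 API key is available in the environment (see the README changes below for the canonical examples):

```python
from langchain_ai21.chat_models import ChatAI21
from langchain_core.messages import HumanMessage

# Assumes AI21_API_KEY is set in the environment.
chat = ChatAI21(model="jamba-instruct", streaming=True)

# invoke() still returns a single complete AIMessage; internally the chunks
# produced by the streaming path are collected via generate_from_stream().
message = chat.invoke([HumanMessage(content="Hello from AI21")])
print(message.content)

# Or consume the chunks directly as they arrive.
for chunk in chat.stream([HumanMessage(content="Hello from AI21")]):
    print(chunk.content, end="", flush=True)
```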

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Authored by Asaf Joseph Gardin on 2024-07-03 00:15:46 +03:00; committed by GitHub.
Commit 320dc31822 (parent 5cd4083457)
7 changed files with 253 additions and 36 deletions

File: README.md (langchain-ai21 partner package)

@@ -28,13 +28,40 @@ Then initialize
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct-preview")
chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
chat.invoke(messages)
```
For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat)
### Streaming in Chat
Streaming is supported by the latest models. To use streaming, set the `streaming` parameter to `True` when initializing the model.
```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct", streaming=True)
messages = [HumanMessage(content="Hello from AI21")]
response = chat.invoke(messages)
```
or use the `stream` method directly
```python
from langchain_core.messages import HumanMessage
from langchain_ai21.chat_models import ChatAI21
chat = ChatAI21(model="jamba-instruct")
messages = [HumanMessage(content="Hello from AI21")]
for chunk in chat.stream(messages):
print(chunk)
```
## LLMs
You can use AI21's Jurassic generative AI models as LangChain LLMs.
To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which

File: chat adapter module (ChatAdapter, J2ChatAdapter, JambaChatCompletionsAdapter)

@@ -1,12 +1,21 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union, cast
from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from ai21.models.chat import ChatCompletionChunk, ChatMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGenerationChunk
_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
@@ -63,8 +72,31 @@ class ChatAdapter(ABC):
) -> _ChatMessageTypes:
pass
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]:
pass
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]:
pass
@abstractmethod
def call(self, client: Any, **params: Any) -> List[BaseMessage]:
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
pass
def _get_system_message_from_message(self, message: BaseMessage) -> str:
@@ -102,7 +134,33 @@ class J2ChatAdapter(ChatAdapter):
) -> J2ChatMessage:
return J2ChatMessage(role=RoleType(role), text=content)
def call(self, client: Any, **params: Any) -> List[BaseMessage]:
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]:
...
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]:
...
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
if stream:
raise NotImplementedError("Streaming is not supported for Jurassic models.")
response = client.chat.create(**params)
return [AIMessage(output.text) for output in response.outputs]
@@ -128,7 +186,62 @@ class JambaChatCompletionsAdapter(ChatAdapter):
content=content,
)
def call(self, client: Any, **params: Any) -> List[BaseMessage]:
response = client.chat.completions.create(**params)
@overload
def call(
self,
client: Any,
stream: Literal[True],
**params: Any,
) -> Iterator[ChatGenerationChunk]:
...
@overload
def call(
self,
client: Any,
stream: Literal[False],
**params: Any,
) -> List[BaseMessage]:
...
def call(
self,
client: Any,
stream: Literal[True] | Literal[False],
**params: Any,
) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
response = client.chat.completions.create(stream=stream, **params)
if stream:
return self._stream_response(response)
return [AIMessage(choice.message.content) for choice in response.choices]
def _stream_response(
self,
response: AI21Stream[ChatCompletionChunk],
) -> Iterator[ChatGenerationChunk]:
for chunk in response:
converted_message = self._convert_ai21_chunk_to_chunk(chunk)
yield ChatGenerationChunk(message=converted_message)
def _convert_ai21_chunk_to_chunk(
self,
chunk: ChatCompletionChunk,
) -> BaseMessageChunk:
usage = chunk.usage
content = chunk.choices[0].delta.content or ""
if usage is None:
return AIMessageChunk(
content=content,
)
return AIMessageChunk(
content=content,
usage_metadata=UsageMetadata(
input_tokens=usage.prompt_tokens,
output_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
),
)
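
Per `_convert_ai21_chunk_to_chunk` above, `usage_metadata` is only attached when the underlying AI21 chunk reports usage. A minimal sketch of merging the streamed chunks on the caller side (assuming an AI21 API key is configured; whether usage is reported depends on the API response):

```python
from langchain_ai21.chat_models import ChatAI21
from langchain_core.messages import HumanMessage

chat = ChatAI21(model="jamba-instruct")

# AIMessageChunk supports "+", so streamed chunks can be merged into a single
# message; usage_metadata is populated only if the API reported token counts.
aggregate = None
for chunk in chat.stream([HumanMessage(content="What is the meaning of life?")]):
    aggregate = chunk if aggregate is None else aggregate + chunk

if aggregate is not None:
    print(aggregate.content)
    print(aggregate.usage_metadata)  # None if no usage was reported
```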

File: chat_models.py (ChatAI21)

@@ -1,16 +1,20 @@
import asyncio
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
generate_from_stream,
)
from langchain_core.messages import (
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_ai21.ai21_base import AI21Base
@@ -75,6 +79,7 @@ class ChatAI21(BaseChatModel, AI21Base):
n: int = 1
"""Number of chat completions to generate for each prompt."""
streaming: bool = False
_chat_adapter: ChatAdapter
@@ -166,14 +171,65 @@ class ChatAI21(BaseChatModel, AI21Base):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
should_stream = stream or self.streaming
if should_stream:
return self._handle_stream_from_generate(
messages=messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
params = self._build_params_for_request(
messages=messages,
stop=stop,
stream=should_stream,
**kwargs,
)
messages = self._chat_adapter.call(self.client, **params)
generations = [ChatGeneration(message=message) for message in messages]
return ChatResult(generations=generations)
def _handle_stream_from_generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
stream_iter = self._stream(
messages=messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
return generate_from_stream(stream_iter)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._build_params_for_request(
messages=messages,
stop=stop,
stream=True,
**kwargs,
)
for chunk in self._chat_adapter.call(self.client, **params):
if run_manager and isinstance(chunk.message.content, str):
run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
yield chunk
async def _agenerate(
self,
messages: List[BaseMessage],
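
Because `_stream` forwards each text chunk through `run_manager.on_llm_new_token`, standard streaming callback handlers receive tokens as they are generated. A small sketch, again assuming an AI21 API key is configured:

```python
from langchain_ai21.chat_models import ChatAI21
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage

chat = ChatAI21(model="jamba-instruct", streaming=True)

# Each new token is reported via on_llm_new_token, so the handler prints the
# response to stdout while it is still being generated.
chat.invoke(
    [HumanMessage(content="Hello from AI21")],
    config={"callbacks": [StreamingStdOutCallbackHandler()]},
)
```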

File: poetry.lock

@@ -2,17 +2,17 @@
[[package]]
name = "ai21"
version = "2.4.1"
version = "2.7.0"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "ai21-2.4.1-py3-none-any.whl", hash = "sha256:2d5e341d01343c8e61598f374b37cabb67ba5857abe128cec97c2ba870403def"},
{file = "ai21-2.4.1.tar.gz", hash = "sha256:55380eee2d7dae6973522ea68f4d39641cb6058337aaa07746811e3d07e4bd73"},
{file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
{file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
]
[package.dependencies]
ai21-tokenizer = ">=0.9.1,<1.0.0"
ai21-tokenizer = ">=0.11.0,<1.0.0"
dataclasses-json = ">=0.6.3,<0.7.0"
httpx = ">=0.27.0,<0.28.0"
tenacity = ">=8.3.0,<9.0.0"
@@ -23,16 +23,17 @@ aws = ["boto3 (>=1.28.82,<2.0.0)"]
[[package]]
name = "ai21-tokenizer"
version = "0.9.1"
version = "0.11.2"
description = ""
optional = false
python-versions = "<4.0,>=3.7"
python-versions = "<4.0,>=3.8"
files = [
{file = "ai21_tokenizer-0.9.1-py3-none-any.whl", hash = "sha256:53e85ddd74b375e7384e5b5c87ff60e63f60e8615c82ce80f27a062b530798ca"},
{file = "ai21_tokenizer-0.9.1.tar.gz", hash = "sha256:28c72122ef1eda6ba03c8ae55916de74a7ee9504c759e9bbae44abe821955df9"},
{file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
{file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
]
[package.dependencies]
anyio = ">=4.4.0,<5.0.0"
sentencepiece = ">=0.2.0,<1.0.0"
tokenizers = ">=0.15.0,<1.0.0"
@@ -505,18 +506,21 @@ url = "../../text-splitters"
[[package]]
name = "langsmith"
version = "0.1.79"
version = "0.1.82"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.79-py3-none-any.whl", hash = "sha256:c7f2c23981917713b5515b773f37c84ff68a7adf803476e2ebb5adcb36a04202"},
{file = "langsmith-0.1.79.tar.gz", hash = "sha256:d215718cfdcdf4a011126b7a3d4a37eee96d887e59ac1e628a57e24b2bfa3163"},
{file = "langsmith-0.1.82-py3-none-any.whl", hash = "sha256:9b3653e7d316036b0c60bf0bc3e280662d660f485a4ebd8e5c9d84f9831ae79c"},
{file = "langsmith-0.1.82.tar.gz", hash = "sha256:c02e2bbc488c10c13b52c69d271eb40bd38da078d37b6ae7ae04a18bd48140be"},
]
[package.dependencies]
orjson = ">=3.9.14,<4.0.0"
pydantic = ">=1,<3"
pydantic = [
{version = ">=1,<3", markers = "python_full_version < \"3.12.4\""},
{version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
]
requests = ">=2,<3"
[[package]]
@@ -901,7 +905,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1334,4 +1337,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "5817da5e2d48e6cc0a625ac6dea7336bc8088a3de7bea588107b143c253f4bcf"
content-hash = "610cb0bc2d9093fa742c3fdc6f839a0560a619d1fab1103b44f8155f31ab4045"

File: pyproject.toml

@@ -14,7 +14,7 @@ license = "MIT"
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.4.1"
ai21 = "^2.7.0"
[tool.poetry.group.test]
optional = true

File: ChatAI21 integration tests (test_chat_models.py)

@@ -1,6 +1,6 @@
"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.outputs import ChatGeneration
from langchain_ai21.chat_models import ChatAI21
@@ -85,3 +85,21 @@ async def test_ageneration(model: str) -> None:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test__chat_stream() -> None:
llm = ChatAI21(model="jamba-instruct")
message = HumanMessage(content="What is the meaning of life?")
for chunk in llm.stream([message]):
assert isinstance(chunk, AIMessageChunk)
assert isinstance(chunk.content, str)
def test__j2_chat_stream__should_raise_error() -> None:
llm = ChatAI21(model="j2-ultra")
message = HumanMessage(content="What is the meaning of life?")
with pytest.raises(NotImplementedError):
for _ in llm.stream([message]):
pass
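
Note that this PR does not add a dedicated async streaming path; `astream` is expected to go through langchain-core's default fallback over the synchronous `_stream` (the `test_astream` xfail in the next file is kept only for the J2 models). A minimal async sketch under that assumption:

```python
import asyncio

from langchain_ai21.chat_models import ChatAI21
from langchain_core.messages import HumanMessage


async def main() -> None:
    # Assumes AI21_API_KEY is set; astream here relies on langchain-core's
    # default fallback that drives the sync _stream in a worker thread.
    chat = ChatAI21(model="jamba-instruct")
    async for chunk in chat.astream([HumanMessage(content="Hello from AI21")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```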

File: standard tests (test_standard.py)

@@ -19,14 +19,6 @@ class BaseTestAI21(ChatModelIntegrationTests):
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream(self, model: BaseChatModel) -> None:
super().test_stream(model)
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
async def test_astream(self, model: BaseChatModel) -> None:
await super().test_astream(model)
@pytest.mark.xfail(reason="Not implemented.")
def test_usage_metadata(self, model: BaseChatModel) -> None:
super().test_usage_metadata(model)
@@ -39,6 +31,14 @@ class TestAI21J2(BaseTestAI21):
"model": "j2-ultra",
}
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream(self, model: BaseChatModel) -> None:
super().test_stream(model)
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
async def test_astream(self, model: BaseChatModel) -> None:
await super().test_astream(model)
class TestAI21Jamba(BaseTestAI21):
@property