partners: AI21 Labs Jamba Streaming Support (#23538)
Thank you for contributing to LangChain!

- [x] **PR title**: "package: description"
- [x] **PR message**:
  - **Description:** Added support for streaming in the AI21 Jamba model
  - **Twitter handle:** https://github.com/AI21Labs
- [x] **Add tests and docs**
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See the contribution guidelines for more: https://python.langchain.com/docs/contributing/

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent 5cd4083457
commit 320dc31822
@@ -28,13 +28,40 @@ Then initialize
 from langchain_core.messages import HumanMessage
 from langchain_ai21.chat_models import ChatAI21

-chat = ChatAI21(model="jamba-instruct-preview")
+chat = ChatAI21(model="jamba-instruct")
 messages = [HumanMessage(content="Hello from AI21")]
 chat.invoke(messages)
 ```

 For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat)

+### Streaming in Chat
+
+Streaming is supported by the latest models. To use streaming, set the `streaming` parameter to `True` when initializing the model.
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_ai21.chat_models import ChatAI21
+
+chat = ChatAI21(model="jamba-instruct", streaming=True)
+messages = [HumanMessage(content="Hello from AI21")]
+
+response = chat.invoke(messages)
+```
+
+or use the `stream` method directly
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_ai21.chat_models import ChatAI21
+
+chat = ChatAI21(model="jamba-instruct")
+messages = [HumanMessage(content="Hello from AI21")]
+
+for chunk in chat.stream(messages):
+    print(chunk)
+```
+
 ## LLMs

 You can use AI21's Jurassic generative AI models as LangChain LLMs.
 To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which
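Not part of this diff, but as a usage note: since `_stream` (further down in this PR) forwards each token to `run_manager.on_llm_new_token`, the `streaming=True` mode should also play with standard callback handlers. A minimal sketch, assuming `langchain_core`'s stdout handler and a configured AI21 API key:

```python
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage

from langchain_ai21.chat_models import ChatAI21

# Each streamed token is forwarded to the handler's on_llm_new_token, so it is
# printed as it arrives; invoke() still returns the fully assembled message.
chat = ChatAI21(
    model="jamba-instruct",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)
response = chat.invoke([HumanMessage(content="Hello from AI21")])
```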
@@ -1,12 +1,21 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload

 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatMessage
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from ai21.models.chat import ChatCompletionChunk, ChatMessage
+from ai21.stream.stream import Stream as AI21Stream
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+)
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGenerationChunk

 _ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
 _SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
@@ -63,8 +72,31 @@ class ChatAdapter(ABC):
     ) -> _ChatMessageTypes:
         pass

+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True],
+        **params: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        pass
+
+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage]:
+        pass
+
     @abstractmethod
-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True] | Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
         pass

     def _get_system_message_from_message(self, message: BaseMessage) -> str:
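A side note on the typing pattern introduced here: the paired `@overload` declarations let a type checker resolve the return type from the literal value of `stream`, while a single runtime implementation handles both cases. A standalone sketch with a hypothetical `Adapter` class (not the AI21 code):

```python
from typing import Iterator, List, Literal, Union, overload


class Adapter:
    # Overloads are seen only by the type checker: stream=True maps to an
    # iterator, stream=False to a list. The single runtime definition follows.
    @overload
    def call(self, stream: Literal[True]) -> Iterator[str]: ...

    @overload
    def call(self, stream: Literal[False]) -> List[str]: ...

    def call(self, stream: bool) -> Union[Iterator[str], List[str]]:
        if stream:
            return iter(["chunk-1", "chunk-2"])
        return ["full response"]


adapter = Adapter()
chunks: Iterator[str] = adapter.call(stream=True)  # checker infers Iterator[str]
full: List[str] = adapter.call(stream=False)       # checker infers List[str]
```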
@@ -102,7 +134,33 @@ class J2ChatAdapter(ChatAdapter):
     ) -> J2ChatMessage:
         return J2ChatMessage(role=RoleType(role), text=content)

-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True],
+        **params: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        ...
+
+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage]:
+        ...
+
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True] | Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
+        if stream:
+            raise NotImplementedError("Streaming is not supported for Jurassic models.")
+
         response = client.chat.create(**params)

         return [AIMessage(output.text) for output in response.outputs]
@@ -128,7 +186,62 @@ class JambaChatCompletionsAdapter(ChatAdapter):
             content=content,
         )

-    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
-        response = client.chat.completions.create(**params)
+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True],
+        **params: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        ...
+
+    @overload
+    def call(
+        self,
+        client: Any,
+        stream: Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage]:
+        ...
+
+    def call(
+        self,
+        client: Any,
+        stream: Literal[True] | Literal[False],
+        **params: Any,
+    ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]:
+        response = client.chat.completions.create(stream=stream, **params)
+
+        if stream:
+            return self._stream_response(response)

         return [AIMessage(choice.message.content) for choice in response.choices]

+    def _stream_response(
+        self,
+        response: AI21Stream[ChatCompletionChunk],
+    ) -> Iterator[ChatGenerationChunk]:
+        for chunk in response:
+            converted_message = self._convert_ai21_chunk_to_chunk(chunk)
+            yield ChatGenerationChunk(message=converted_message)
+
+    def _convert_ai21_chunk_to_chunk(
+        self,
+        chunk: ChatCompletionChunk,
+    ) -> BaseMessageChunk:
+        usage = chunk.usage
+        content = chunk.choices[0].delta.content or ""
+
+        if usage is None:
+            return AIMessageChunk(
+                content=content,
+            )
+
+        return AIMessageChunk(
+            content=content,
+            usage_metadata=UsageMetadata(
+                input_tokens=usage.prompt_tokens,
+                output_tokens=usage.completion_tokens,
+                total_tokens=usage.total_tokens,
+            ),
+        )
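To make the chunk conversion above concrete, here is a minimal sketch using stand-in objects in place of the `ai21` SDK's `ChatCompletionChunk`; the field names mirror the diff, but the `Fake*` classes are hypothetical:

```python
from dataclasses import dataclass, field
from typing import List, Optional

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata


# Stand-ins for the ai21 SDK chunk objects (hypothetical shapes, for illustration).
@dataclass
class FakeDelta:
    content: Optional[str] = None


@dataclass
class FakeChoice:
    delta: FakeDelta = field(default_factory=FakeDelta)


@dataclass
class FakeUsage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


@dataclass
class FakeChunk:
    choices: List[FakeChoice] = field(default_factory=list)
    usage: Optional[FakeUsage] = None


def convert_chunk(chunk: FakeChunk) -> AIMessageChunk:
    """Mirror the adapter logic: plain content chunks, usage attached to the final one."""
    content = chunk.choices[0].delta.content or ""
    if chunk.usage is None:
        return AIMessageChunk(content=content)
    return AIMessageChunk(
        content=content,
        usage_metadata=UsageMetadata(
            input_tokens=chunk.usage.prompt_tokens,
            output_tokens=chunk.usage.completion_tokens,
            total_tokens=chunk.usage.total_tokens,
        ),
    )


print(convert_chunk(FakeChunk(choices=[FakeChoice(FakeDelta("Hello"))])))
print(convert_chunk(FakeChunk(choices=[FakeChoice()], usage=FakeUsage(5, 7, 12))))
```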
@@ -1,16 +1,20 @@
 import asyncio
 from functools import partial
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    LangSmithParams,
+    generate_from_stream,
+)
 from langchain_core.messages import (
     BaseMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator

 from langchain_ai21.ai21_base import AI21Base
@@ -75,6 +79,7 @@ class ChatAI21(BaseChatModel, AI21Base):

     n: int = 1
     """Number of chat completions to generate for each prompt."""
+    streaming: bool = False

     _chat_adapter: ChatAdapter

@@ -166,14 +171,65 @@ class ChatAI21(BaseChatModel, AI21Base):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
+        should_stream = stream or self.streaming
+
+        if should_stream:
+            return self._handle_stream_from_generate(
+                messages=messages,
+                stop=stop,
+                run_manager=run_manager,
+                **kwargs,
+            )
+
+        params = self._build_params_for_request(
+            messages=messages,
+            stop=stop,
+            stream=should_stream,
+            **kwargs,
+        )
+
         messages = self._chat_adapter.call(self.client, **params)
         generations = [ChatGeneration(message=message) for message in messages]

         return ChatResult(generations=generations)

+    def _handle_stream_from_generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        stream_iter = self._stream(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        )
+        return generate_from_stream(stream_iter)
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        params = self._build_params_for_request(
+            messages=messages,
+            stop=stop,
+            stream=True,
+            **kwargs,
+        )
+
+        for chunk in self._chat_adapter.call(self.client, **params):
+            if run_manager and isinstance(chunk.message.content, str):
+                run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk)
+
+            yield chunk
+
     async def _agenerate(
         self,
         messages: List[BaseMessage],
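For reference, a small sketch of what `generate_from_stream` does with the iterator produced by `_stream`: it concatenates the chunks and wraps the merged message in a `ChatResult`, which is how `streaming=True` still satisfies the `_generate` contract. The chunk contents below are made up:

```python
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

# A fake chunk iterator standing in for ChatAI21._stream(); generate_from_stream
# folds the chunks into a single message and returns a ChatResult.
chunks = iter(
    [
        ChatGenerationChunk(message=AIMessageChunk(content="Hello ")),
        ChatGenerationChunk(message=AIMessageChunk(content="from AI21")),
    ]
)

result = generate_from_stream(chunks)
print(result.generations[0].message.content)  # "Hello from AI21"
```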
libs/partners/ai21/poetry.lock (generated)
@@ -2,17 +2,17 @@

 [[package]]
 name = "ai21"
-version = "2.4.1"
+version = "2.7.0"
 description = ""
 optional = false
 python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21-2.4.1-py3-none-any.whl", hash = "sha256:2d5e341d01343c8e61598f374b37cabb67ba5857abe128cec97c2ba870403def"},
-    {file = "ai21-2.4.1.tar.gz", hash = "sha256:55380eee2d7dae6973522ea68f4d39641cb6058337aaa07746811e3d07e4bd73"},
+    {file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
+    {file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
 ]

 [package.dependencies]
-ai21-tokenizer = ">=0.9.1,<1.0.0"
+ai21-tokenizer = ">=0.11.0,<1.0.0"
 dataclasses-json = ">=0.6.3,<0.7.0"
 httpx = ">=0.27.0,<0.28.0"
 tenacity = ">=8.3.0,<9.0.0"
@@ -23,16 +23,17 @@ aws = ["boto3 (>=1.28.82,<2.0.0)"]

 [[package]]
 name = "ai21-tokenizer"
-version = "0.9.1"
+version = "0.11.2"
 description = ""
 optional = false
-python-versions = "<4.0,>=3.7"
+python-versions = "<4.0,>=3.8"
 files = [
-    {file = "ai21_tokenizer-0.9.1-py3-none-any.whl", hash = "sha256:53e85ddd74b375e7384e5b5c87ff60e63f60e8615c82ce80f27a062b530798ca"},
-    {file = "ai21_tokenizer-0.9.1.tar.gz", hash = "sha256:28c72122ef1eda6ba03c8ae55916de74a7ee9504c759e9bbae44abe821955df9"},
+    {file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
+    {file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
 ]

 [package.dependencies]
+anyio = ">=4.4.0,<5.0.0"
 sentencepiece = ">=0.2.0,<1.0.0"
 tokenizers = ">=0.15.0,<1.0.0"

@@ -505,18 +506,21 @@ url = "../../text-splitters"

 [[package]]
 name = "langsmith"
-version = "0.1.79"
+version = "0.1.82"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.79-py3-none-any.whl", hash = "sha256:c7f2c23981917713b5515b773f37c84ff68a7adf803476e2ebb5adcb36a04202"},
-    {file = "langsmith-0.1.79.tar.gz", hash = "sha256:d215718cfdcdf4a011126b7a3d4a37eee96d887e59ac1e628a57e24b2bfa3163"},
+    {file = "langsmith-0.1.82-py3-none-any.whl", hash = "sha256:9b3653e7d316036b0c60bf0bc3e280662d660f485a4ebd8e5c9d84f9831ae79c"},
+    {file = "langsmith-0.1.82.tar.gz", hash = "sha256:c02e2bbc488c10c13b52c69d271eb40bd38da078d37b6ae7ae04a18bd48140be"},
 ]

 [package.dependencies]
 orjson = ">=3.9.14,<4.0.0"
-pydantic = ">=1,<3"
+pydantic = [
+    {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""},
+    {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""},
+]
 requests = ">=2,<3"

 [[package]]
@@ -901,7 +905,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1334,4 +1337,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "5817da5e2d48e6cc0a625ac6dea7336bc8088a3de7bea588107b143c253f4bcf"
+content-hash = "610cb0bc2d9093fa742c3fdc6f839a0560a619d1fab1103b44f8155f31ab4045"
@@ -14,7 +14,7 @@ license = "MIT"
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.2.4"
 langchain-text-splitters = "^0.2.0"
-ai21 = "^2.4.1"
+ai21 = "^2.7.0"

 [tool.poetry.group.test]
 optional = true
@@ -1,6 +1,6 @@
 """Test ChatAI21 chat model."""
 import pytest
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import AIMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration

 from langchain_ai21.chat_models import ChatAI21
@@ -85,3 +85,21 @@ async def test_ageneration(model: str) -> None:
     assert isinstance(generation, ChatGeneration)
     assert isinstance(generation.text, str)
     assert generation.text == generation.message.content
+
+
+def test__chat_stream() -> None:
+    llm = ChatAI21(model="jamba-instruct")
+    message = HumanMessage(content="What is the meaning of life?")
+
+    for chunk in llm.stream([message]):
+        assert isinstance(chunk, AIMessageChunk)
+        assert isinstance(chunk.content, str)
+
+
+def test__j2_chat_stream__should_raise_error() -> None:
+    llm = ChatAI21(model="j2-ultra")
+    message = HumanMessage(content="What is the meaning of life?")
+
+    with pytest.raises(NotImplementedError):
+        for _ in llm.stream([message]):
+            pass
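Related to the streaming test above: message chunks support `+`, so streamed output can be reassembled and compared against a full response. A minimal offline sketch (no API call; the contents are made up):

```python
from langchain_core.messages import AIMessageChunk

# AIMessageChunk implements "+", which is what makes incremental streaming easy
# to reassemble; a test could collect chunks from llm.stream(...) the same way.
chunks = [AIMessageChunk(content="What is"), AIMessageChunk(content=" the meaning of life?")]

merged = chunks[0]
for chunk in chunks[1:]:
    merged = merged + chunk

assert merged.content == "What is the meaning of life?"
```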
@@ -19,14 +19,6 @@ class BaseTestAI21(ChatModelIntegrationTests):
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatAI21

-    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
-    def test_stream(self, model: BaseChatModel) -> None:
-        super().test_stream(model)
-
-    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
-    async def test_astream(self, model: BaseChatModel) -> None:
-        await super().test_astream(model)
-
     @pytest.mark.xfail(reason="Not implemented.")
     def test_usage_metadata(self, model: BaseChatModel) -> None:
         super().test_usage_metadata(model)
@@ -39,6 +31,14 @@ class TestAI21J2(BaseTestAI21):
         "model": "j2-ultra",
     }

+    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
+    def test_stream(self, model: BaseChatModel) -> None:
+        super().test_stream(model)
+
+    @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
+    async def test_astream(self, model: BaseChatModel) -> None:
+        await super().test_astream(model)
+

 class TestAI21Jamba(BaseTestAI21):
     @property