diff --git a/libs/partners/ai21/README.md b/libs/partners/ai21/README.md
index 524854c7c7a..df66bc72111 100644
--- a/libs/partners/ai21/README.md
+++ b/libs/partners/ai21/README.md
@@ -28,13 +28,40 @@ Then initialize
 from langchain_core.messages import HumanMessage
 from langchain_ai21.chat_models import ChatAI21
 
-chat = ChatAI21(model="jamba-instruct-preview")
+chat = ChatAI21(model="jamba-instruct")
 messages = [HumanMessage(content="Hello from AI21")]
 chat.invoke(messages)
 ```
 
 For a list of the supported models, see [this page](https://docs.ai21.com/reference/python-sdk#chat)
 
+### Streaming in Chat
+Streaming is supported by the newer Jamba models; the Jurassic (J2) models do not support it. To use streaming, set the `streaming` parameter to `True` when initializing the model:
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_ai21.chat_models import ChatAI21
+
+chat = ChatAI21(model="jamba-instruct", streaming=True)
+messages = [HumanMessage(content="Hello from AI21")]
+
+response = chat.invoke(messages)
+```
+
+Or use the `stream` method directly:
+
+```python
+from langchain_core.messages import HumanMessage
+from langchain_ai21.chat_models import ChatAI21
+
+chat = ChatAI21(model="jamba-instruct")
+messages = [HumanMessage(content="Hello from AI21")]
+
+for chunk in chat.stream(messages):
+    print(chunk)
+```
+
+
 ## LLMs
 You can use AI21's Jurassic generative AI models as LangChain LLMs.
 To use the newer Jamba model, use the [ChatAI21 chat model](#chat-models), which
diff --git a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
index 67f70ca381b..a7441e17ba1 100644
--- a/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
+++ b/libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
@@ -1,12 +1,21 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
 
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models import RoleType
-from ai21.models.chat import ChatMessage
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from ai21.models.chat import ChatCompletionChunk, ChatMessage
+from ai21.stream.stream import Stream as AI21Stream
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    BaseMessageChunk,
+    HumanMessage,
+)
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.outputs import ChatGenerationChunk
 
 _ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
 _SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
@@ -63,8 +72,31 @@ class ChatAdapter(ABC): ) -> _ChatMessageTypes: pass + @overload + def call( + self, + client: Any, + stream: Literal[True], + **params: Any, + ) -> Iterator[ChatGenerationChunk]: + pass + + @overload + def call( + self, + client: Any, + stream: Literal[False], + **params: Any, + ) -> List[BaseMessage]: + pass + @abstractmethod - def call(self, client: Any, **params: Any) -> List[BaseMessage]: + def call( + self, + client: Any, + stream: Literal[True] | Literal[False], + **params: Any, + ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]: pass def _get_system_message_from_message(self, message: BaseMessage) -> str: @@ -102,7 +134,33 @@ class J2ChatAdapter(ChatAdapter): ) -> J2ChatMessage: return J2ChatMessage(role=RoleType(role), text=content) - def call(self, client: Any, **params: Any) -> List[BaseMessage]: + @overload + def call( + self, + client: Any, + stream: Literal[True], + **params: Any, + ) -> Iterator[ChatGenerationChunk]: + ... + + @overload + def call( + self, + client: Any, + stream: Literal[False], + **params: Any, + ) -> List[BaseMessage]: + ... + + def call( + self, + client: Any, + stream: Literal[True] | Literal[False], + **params: Any, + ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]: + if stream: + raise NotImplementedError("Streaming is not supported for Jurassic models.") + response = client.chat.create(**params) return [AIMessage(output.text) for output in response.outputs] @@ -128,7 +186,62 @@ class JambaChatCompletionsAdapter(ChatAdapter): content=content, ) - def call(self, client: Any, **params: Any) -> List[BaseMessage]: - response = client.chat.completions.create(**params) + @overload + def call( + self, + client: Any, + stream: Literal[True], + **params: Any, + ) -> Iterator[ChatGenerationChunk]: + ... + + @overload + def call( + self, + client: Any, + stream: Literal[False], + **params: Any, + ) -> List[BaseMessage]: + ... 
+ + def call( + self, + client: Any, + stream: Literal[True] | Literal[False], + **params: Any, + ) -> List[BaseMessage] | Iterator[ChatGenerationChunk]: + response = client.chat.completions.create(stream=stream, **params) + + if stream: + return self._stream_response(response) return [AIMessage(choice.message.content) for choice in response.choices] + + def _stream_response( + self, + response: AI21Stream[ChatCompletionChunk], + ) -> Iterator[ChatGenerationChunk]: + for chunk in response: + converted_message = self._convert_ai21_chunk_to_chunk(chunk) + yield ChatGenerationChunk(message=converted_message) + + def _convert_ai21_chunk_to_chunk( + self, + chunk: ChatCompletionChunk, + ) -> BaseMessageChunk: + usage = chunk.usage + content = chunk.choices[0].delta.content or "" + + if usage is None: + return AIMessageChunk( + content=content, + ) + + return AIMessageChunk( + content=content, + usage_metadata=UsageMetadata( + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + ), + ) diff --git a/libs/partners/ai21/langchain_ai21/chat_models.py b/libs/partners/ai21/langchain_ai21/chat_models.py index 9150ab4d883..4e31f66ff0f 100644 --- a/libs/partners/ai21/langchain_ai21/chat_models.py +++ b/libs/partners/ai21/langchain_ai21/chat_models.py @@ -1,16 +1,20 @@ import asyncio from functools import partial -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams +from langchain_core.language_models.chat_models import ( + BaseChatModel, + LangSmithParams, + generate_from_stream, +) from langchain_core.messages import ( BaseMessage, ) -from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import root_validator from langchain_ai21.ai21_base import AI21Base @@ -75,6 +79,7 @@ class ChatAI21(BaseChatModel, AI21Base): n: int = 1 """Number of chat completions to generate for each prompt.""" + streaming: bool = False _chat_adapter: ChatAdapter @@ -166,14 +171,65 @@ class ChatAI21(BaseChatModel, AI21Base): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: - params = self._build_params_for_request(messages=messages, stop=stop, **kwargs) + should_stream = stream or self.streaming + + if should_stream: + return self._handle_stream_from_generate( + messages=messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + + params = self._build_params_for_request( + messages=messages, + stop=stop, + stream=should_stream, + **kwargs, + ) + messages = self._chat_adapter.call(self.client, **params) generations = [ChatGeneration(message=message) for message in messages] return ChatResult(generations=generations) + def _handle_stream_from_generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + stream_iter = self._stream( + messages=messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + return generate_from_stream(stream_iter) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + 
run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + params = self._build_params_for_request( + messages=messages, + stop=stop, + stream=True, + **kwargs, + ) + + for chunk in self._chat_adapter.call(self.client, **params): + if run_manager and isinstance(chunk.message.content, str): + run_manager.on_llm_new_token(token=chunk.message.content, chunk=chunk) + yield chunk + async def _agenerate( self, messages: List[BaseMessage], diff --git a/libs/partners/ai21/poetry.lock b/libs/partners/ai21/poetry.lock index c19e17fb36a..39369ee3672 100644 --- a/libs/partners/ai21/poetry.lock +++ b/libs/partners/ai21/poetry.lock @@ -2,17 +2,17 @@ [[package]] name = "ai21" -version = "2.4.1" +version = "2.7.0" description = "" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "ai21-2.4.1-py3-none-any.whl", hash = "sha256:2d5e341d01343c8e61598f374b37cabb67ba5857abe128cec97c2ba870403def"}, - {file = "ai21-2.4.1.tar.gz", hash = "sha256:55380eee2d7dae6973522ea68f4d39641cb6058337aaa07746811e3d07e4bd73"}, + {file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"}, + {file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"}, ] [package.dependencies] -ai21-tokenizer = ">=0.9.1,<1.0.0" +ai21-tokenizer = ">=0.11.0,<1.0.0" dataclasses-json = ">=0.6.3,<0.7.0" httpx = ">=0.27.0,<0.28.0" tenacity = ">=8.3.0,<9.0.0" @@ -23,16 +23,17 @@ aws = ["boto3 (>=1.28.82,<2.0.0)"] [[package]] name = "ai21-tokenizer" -version = "0.9.1" +version = "0.11.2" description = "" optional = false -python-versions = "<4.0,>=3.7" +python-versions = "<4.0,>=3.8" files = [ - {file = "ai21_tokenizer-0.9.1-py3-none-any.whl", hash = "sha256:53e85ddd74b375e7384e5b5c87ff60e63f60e8615c82ce80f27a062b530798ca"}, - {file = "ai21_tokenizer-0.9.1.tar.gz", hash = "sha256:28c72122ef1eda6ba03c8ae55916de74a7ee9504c759e9bbae44abe821955df9"}, + {file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"}, + {file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"}, ] [package.dependencies] +anyio = ">=4.4.0,<5.0.0" sentencepiece = ">=0.2.0,<1.0.0" tokenizers = ">=0.15.0,<1.0.0" @@ -505,18 +506,21 @@ url = "../../text-splitters" [[package]] name = "langsmith" -version = "0.1.79" +version = "0.1.82" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.79-py3-none-any.whl", hash = "sha256:c7f2c23981917713b5515b773f37c84ff68a7adf803476e2ebb5adcb36a04202"}, - {file = "langsmith-0.1.79.tar.gz", hash = "sha256:d215718cfdcdf4a011126b7a3d4a37eee96d887e59ac1e628a57e24b2bfa3163"}, + {file = "langsmith-0.1.82-py3-none-any.whl", hash = "sha256:9b3653e7d316036b0c60bf0bc3e280662d660f485a4ebd8e5c9d84f9831ae79c"}, + {file = "langsmith-0.1.82.tar.gz", hash = "sha256:c02e2bbc488c10c13b52c69d271eb40bd38da078d37b6ae7ae04a18bd48140be"}, ] [package.dependencies] orjson = ">=3.9.14,<4.0.0" -pydantic = ">=1,<3" +pydantic = [ + {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, +] requests = ">=2,<3" [[package]] @@ -901,7 +905,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1334,4 +1337,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "5817da5e2d48e6cc0a625ac6dea7336bc8088a3de7bea588107b143c253f4bcf" +content-hash = "610cb0bc2d9093fa742c3fdc6f839a0560a619d1fab1103b44f8155f31ab4045" diff --git a/libs/partners/ai21/pyproject.toml b/libs/partners/ai21/pyproject.toml index cc08d75a78a..c054b2f9910 100644 --- a/libs/partners/ai21/pyproject.toml +++ b/libs/partners/ai21/pyproject.toml @@ -14,7 +14,7 @@ license = "MIT" python = ">=3.8.1,<4.0" langchain-core = "^0.2.4" langchain-text-splitters = "^0.2.0" -ai21 = "^2.4.1" +ai21 = "^2.7.0" [tool.poetry.group.test] optional = true diff --git a/libs/partners/ai21/tests/integration_tests/test_chat_models.py b/libs/partners/ai21/tests/integration_tests/test_chat_models.py index d83ff41d4af..1de6b915f5f 100644 --- a/libs/partners/ai21/tests/integration_tests/test_chat_models.py +++ b/libs/partners/ai21/tests/integration_tests/test_chat_models.py @@ -1,6 +1,6 @@ """Test ChatAI21 chat model.""" import pytest -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessageChunk, HumanMessage from langchain_core.outputs import ChatGeneration from langchain_ai21.chat_models import ChatAI21 @@ -85,3 +85,21 @@ async def test_ageneration(model: str) -> None: assert isinstance(generation, ChatGeneration) assert isinstance(generation.text, str) assert generation.text == generation.message.content + + +def test__chat_stream() -> None: + llm = ChatAI21(model="jamba-instruct") + message = HumanMessage(content="What is the meaning of life?") + + for chunk in llm.stream([message]): + 
assert isinstance(chunk, AIMessageChunk) + assert isinstance(chunk.content, str) + + +def test__j2_chat_stream__should_raise_error() -> None: + llm = ChatAI21(model="j2-ultra") + message = HumanMessage(content="What is the meaning of life?") + + with pytest.raises(NotImplementedError): + for _ in llm.stream([message]): + pass diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py index 2235dcd3212..fca99fd2b91 100644 --- a/libs/partners/ai21/tests/integration_tests/test_standard.py +++ b/libs/partners/ai21/tests/integration_tests/test_standard.py @@ -19,14 +19,6 @@ class BaseTestAI21(ChatModelIntegrationTests): def chat_model_class(self) -> Type[BaseChatModel]: return ChatAI21 - @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - def test_stream(self, model: BaseChatModel) -> None: - super().test_stream(model) - - @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") - async def test_astream(self, model: BaseChatModel) -> None: - await super().test_astream(model) - @pytest.mark.xfail(reason="Not implemented.") def test_usage_metadata(self, model: BaseChatModel) -> None: super().test_usage_metadata(model) @@ -39,6 +31,14 @@ class TestAI21J2(BaseTestAI21): "model": "j2-ultra", } + @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") + def test_stream(self, model: BaseChatModel) -> None: + super().test_stream(model) + + @pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.") + async def test_astream(self, model: BaseChatModel) -> None: + await super().test_astream(model) + class TestAI21Jamba(BaseTestAI21): @property