mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 19:18:53 +00:00
anthropic[patch]: streaming param (#18819)
This commit is contained in:
parent
8c0b215c02
commit
a5bcddc738
@ -8,7 +8,11 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
generate_from_stream,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@ -174,6 +178,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to use streaming or not."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
@ -271,6 +278,11 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
data = self._client.messages.create(**params)
|
||||
return self._format_output(data)
|
||||
@ -282,6 +294,11 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
data = await self._async_client.messages.create(**params)
|
||||
return self._format_output(data)
|
||||
|
@ -184,3 +184,31 @@ def test_anthropic_multimodal() -> None:
|
||||
response = chat.invoke(messages)
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_streaming() -> None:
|
||||
"""Test streaming tokens from Anthropic."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
llm = ChatAnthropicMessages(
|
||||
model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
|
||||
)
|
||||
|
||||
response = llm.generate([[HumanMessage(content="I'm Pickle Rick")]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
|
||||
|
||||
async def test_astreaming() -> None:
|
||||
"""Test streaming tokens from Anthropic."""
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
|
||||
llm = ChatAnthropicMessages(
|
||||
model_name=MODEL_NAME, streaming=True, callback_manager=callback_manager
|
||||
)
|
||||
|
||||
response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
|
Loading…
Reference in New Issue
Block a user