Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-22 23:00:00 +00:00
core[minor]: chat model astream falls back to sync stream if available (#18748)
Allows all chat models that implement _stream but not _astream to still support async streaming. Among other things, this should resolve issues with streaming community model implementations through langserve, since langserve is exclusively async.
parent 3624f56ccb
commit cdfb5b4ca1
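
As a rough usage sketch of what this enables (the model class name is hypothetical; it assumes langchain_core with this commit applied and mirrors the ModelWithSyncStream unit test added below), a chat model that only overrides the sync _stream hook can now be consumed through astream:

```python
# Hypothetical example: a chat model that only implements the sync _stream hook.
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult


class SyncOnlyStreamingModel(BaseChatModel):
    """Toy model that streams two fixed chunks; no _astream override."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        raise NotImplementedError()

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
        yield ChatGenerationChunk(message=AIMessageChunk(content="b"))

    @property
    def _llm_type(self) -> str:
        return "sync-only-streaming-model"


async def main() -> None:
    # With this commit, astream bridges to the sync _stream via a thread
    # executor and yields real chunks ("a", then "b").
    async for chunk in SyncOnlyStreamingModel().astream("anything"):
        print(chunk.content)


asyncio.run(main())
```

Before this change, astream would have ignored _stream entirely and fallen back to a single ainvoke call.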
@@ -7,7 +7,9 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     AsyncIterator,
+    Callable,
     Dict,
     Iterator,
     List,
@@ -100,6 +102,26 @@ async def agenerate_from_stream(
     )


+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""

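The _as_async_iterator helper above never blocks the event loop: both creating the sync iterator and each next() call are pushed to a worker thread via run_in_executor, with a sentinel object marking exhaustion. Below is a minimal standalone sketch of the same bridging pattern, written against plain asyncio.to_thread rather than langchain_core's run_in_executor helper (all names here are illustrative):

```python
# Illustrative sketch of the sync-to-async iterator bridge, using plain asyncio.
import asyncio
from typing import Any, AsyncIterator, Callable, Iterator


def as_async_iterator(
    sync_iterator_fn: Callable[..., Iterator[Any]],
) -> Callable[..., AsyncIterator[Any]]:
    """Wrap a callable returning a sync iterator so it can be used with async for."""

    async def wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        # Create the iterator in a worker thread (it may do blocking setup work).
        iterator = await asyncio.to_thread(sync_iterator_fn, *args, **kwargs)
        done = object()  # sentinel: next() returns it when the iterator is exhausted
        while True:
            # Each blocking next() call also runs off the event loop.
            item = await asyncio.to_thread(next, iterator, done)
            if item is done:
                break
            yield item

    return wrapper


def count_to(n: int) -> Iterator[int]:
    yield from range(n)


async def main() -> None:
    async for value in as_async_iterator(count_to)(3):
        print(value)  # prints 0, 1, 2


asyncio.run(main())
```

The sentinel passed as the default to next() avoids having to propagate StopIteration across the thread boundary.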
@@ -259,57 +281,74 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream == BaseChatModel._astream:
-            # model doesn't implement streaming, so use default implementation
-            yield cast(
-                BaseMessageChunk,
-                await self.ainvoke(input, config=config, stop=stop, **kwargs),
-            )
-        else:
-            config = ensure_config(config)
-            messages = self._convert_input(input).to_messages()
-            params = self._get_invocation_params(stop=stop, **kwargs)
-            options = {"stop": stop, **kwargs}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
-            )
-            (run_manager,) = await callback_manager.on_chat_model_start(
-                dumpd(self),
-                [messages],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
-            )
-            generation: Optional[ChatGenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.message
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(
-                    LLMResult(generations=[[generation]]),
-                )
+        if type(self)._astream is not BaseChatModel._astream:
+            # Then astream is implemented
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseChatModel._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        List[BaseMessage],
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[ChatGenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
+            )
+        else:  # No async or sync stream is implemented, so fall back to ainvoke
+            yield cast(
+                BaseMessageChunk,
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
+            return
+
+        config = ensure_config(config)
+        messages = self._convert_input(input).to_messages()
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        options = {"stop": stop, **kwargs}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_chat_model_start(
+            dumpd(self),
+            [messages],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[ChatGenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.message
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
+            )
+            raise e
+        else:
+            await run_manager.on_llm_end(
+                LLMResult(generations=[[generation]]),
+            )

     # --- Custom methods ---

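The branching at the top of the new astream picks one of three paths by checking whether the subclass actually overrides _astream or _stream; comparing type(self)._astream against BaseChatModel._astream is what detects an override. A toy, self-contained sketch of that idiom (class and method names here are made up):

```python
# Toy illustration of the override-detection idiom used by the new astream.
class Base:
    def _astream(self) -> str:
        return "base astream"

    def _stream(self) -> str:
        return "base stream"

    def pick_streaming_path(self) -> str:
        # Comparing class attributes tells us whether a subclass supplied
        # its own implementation, without calling anything.
        if type(self)._astream is not Base._astream:
            return "use the native async stream"
        elif type(self)._stream is not Base._stream:
            return "bridge the sync stream to async"
        return "no streaming at all: fall back to a single invoke"


class SyncOnly(Base):
    def _stream(self) -> str:
        return "sync stream"


class AsyncCapable(Base):
    def _astream(self) -> str:
        return "async stream"


print(Base().pick_streaming_path())          # no streaming at all: fall back to a single invoke
print(SyncOnly().pick_streaming_path())      # bridge the sync stream to async
print(AsyncCapable().pick_streaming_path())  # use the native async stream
```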
@@ -1,8 +1,19 @@
 """Test base chat model."""

+from typing import Any, AsyncIterator, Iterator, List, Optional
+
 import pytest

-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.outputs.llm_result import LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
@@ -106,3 +117,111 @@ async def test_stream_error_callback() -> None:
                 pass

             eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            message = AIMessage(content="hello")
+            generation = ChatGeneration(message=message)
+            return ChatResult(generations=[generation])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [AIMessage(content="hello")]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+    assert type(model)._astream == BaseChatModel._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseChatModel):
+        def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> ChatResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(  # type: ignore
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[ChatGenerationChunk]:
+            """Stream the output of the model."""
+            yield ChatGenerationChunk(message=AIMessageChunk(content="a"))
+            yield ChatGenerationChunk(message=AIMessageChunk(content="b"))
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == [
+        AIMessageChunk(content="a"),
+        AIMessageChunk(content="b"),
+    ]