Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-17 16:39:52 +00:00)
core: Simplify astream logic in BaseChatModel and BaseLLM (#19332)
Covered by tests in `libs/core/tests/unit_tests/language_models/chat_models/test_base.py`, `libs/core/tests/unit_tests/language_models/llms/test_base.py`, and `libs/core/tests/unit_tests/runnables/test_runnable_events.py`.
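The heart of the change is visible in the hunks below: instead of astream() choosing between three code paths at call time, each base class now ships a default async `_astream` that drives the synchronous `_stream` on an executor thread. What follows is a simplified, standalone sketch of that bridging pattern using plain asyncio rather than langchain_core's `run_in_executor` helper; the names `bridge` and `slow_numbers` are illustrative, not part of the library.

import asyncio
from typing import Any, AsyncIterator, Callable, Iterator


async def bridge(make_iterator: Callable[[], Iterator[Any]]) -> AsyncIterator[Any]:
    """Run a sync iterator on a worker thread, re-yielding items asynchronously."""
    loop = asyncio.get_running_loop()
    # Creating the iterator may itself block (e.g. opening a connection),
    # so that also happens on the executor.
    iterator = await loop.run_in_executor(None, make_iterator)
    done = object()  # sentinel returned by next() when the iterator is exhausted
    while True:
        item = await loop.run_in_executor(None, next, iterator, done)
        if item is done:
            break
        yield item


def slow_numbers() -> Iterator[int]:
    yield from range(3)


async def main() -> None:
    async for n in bridge(slow_numbers):
        print(n)  # 0, 1, 2, with the event loop staying responsive between items


# asyncio.run(main())

The commit replaces a module-level helper plus call-site dispatch with this pattern inlined as the default `_astream`, so astream() itself only has to handle the no-streaming fallback.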
This commit is contained in:
parent
40f846e65d
commit
4c2e887276
libs/core/langchain_core/language_models/chat_models.py

@@ -8,9 +8,7 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
     AsyncIterator,
-    Callable,
     Dict,
     Iterator,
     List,
@@ -99,26 +97,6 @@ async def agenerate_from_stream(
     )
 
 
-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
 
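The removed helper (and its replacement in the new default `_astream` further down) pulls items with a sentinel default rather than letting `next()` raise. The reason: a StopIteration escaping through the executor future would surface inside an async generator, where Python converts it to a RuntimeError (PEP 479). A two-line illustration:

it = iter([])                  # an exhausted iterator
done = object()                # unique sentinel object
assert next(it, done) is done  # returns the sentinel instead of raising StopIteration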
@@ -270,28 +248,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream is not BaseChatModel._astream:
-            # Then astream is implemented
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseChatModel._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        List[BaseMessage],
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[ChatGenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:  # No async or sync stream is implemented, so fall back to ainvoke
+        if (
+            type(self)._astream is BaseChatModel._astream
+            and type(self)._stream is BaseChatModel._stream
+        ):
+            # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
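The branch condition above relies on a standard override-detection idiom: look the method up on the class (not the instance) and compare identity with the base class attribute. A minimal illustration with hypothetical classes:

class Base:
    def _stream(self):
        raise NotImplementedError


class WithStream(Base):
    def _stream(self):
        yield "chunk"


assert type(WithStream())._stream is not Base._stream  # subclass overrides it
assert type(Base())._stream is Base._stream            # base default still in place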
@@ -321,13 +282,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             batch_size=1,
         )
 
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 messages,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
@@ -731,14 +691,32 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     ) -> Iterator[ChatGenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            messages,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
     def __call__(
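With the default `_astream` in place, a chat model that implements only the sync `_stream` gets a working astream() for free. A sketch with a hypothetical EchoChatModel; the class and its token-splitting behavior are invented for illustration, while the langchain_core imports are the library's real API:

import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChatModel(BaseChatModel):
    """Hypothetical model: echoes the last message back, token by token."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message = AIMessage(content=messages[-1].content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Sync-only streaming; the inherited _astream bridges this to async.
        for token in str(messages[-1].content).split():
            yield ChatGenerationChunk(message=AIMessageChunk(content=token + " "))


async def main() -> None:
    async for chunk in EchoChatModel().astream("hello streaming world"):
        print(chunk.content, end="", flush=True)


# asyncio.run(main())  # prints: hello streaming world

The same change is now applied to BaseLLM in the second file: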
libs/core/langchain_core/language_models/llms.py

@@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -116,26 +115,6 @@ def create_base_retry_decorator(
     )
 
 
-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
@@ -460,28 +439,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream is not BaseLLM._astream:
-            # model doesn't implement streaming, so use default implementation
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseLLM._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        str,
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[GenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:
+        if (
+            type(self)._astream is BaseLLM._astream
+            and type(self)._stream is BaseLLM._stream
+        ):
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
 
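For the fallback branch just shown: a model that overrides neither `_stream` nor `_astream` still works with astream(), yielding its entire ainvoke() result as a single item. A hypothetical minimal example:

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


class NoStreamLLM(LLM):
    """Hypothetical model with no streaming override at all."""

    @property
    def _llm_type(self) -> str:
        return "no-stream"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt.upper()


# async for piece in NoStreamLLM().astream("abc"):
#     print(piece)  # a single "ABC" chunk, via the ainvoke fallback above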
@@ -509,13 +470,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 prompt,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 yield chunk.text
@@ -571,14 +531,32 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     ) -> Iterator[GenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
        self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            prompt,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     def generate_prompt(
         self,
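And the BaseLLM counterpart of the earlier chat-model example: implementing only the sync `_stream` now yields a working astream() through the inherited default `_astream`. A hypothetical sketch:

from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class EchoLLM(LLM):
    """Hypothetical model: streams the prompt back one character at a time."""

    @property
    def _llm_type(self) -> str:
        return "echo-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Sync generator; the new default _astream runs it on an executor thread.
        for ch in prompt:
            yield GenerationChunk(text=ch)


# async for piece in EchoLLM().astream("abc"):
#     print(piece)  # "a", "b", "c", without blocking the event loop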