Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-17 16:39:52 +00:00)
core: Simplify astream logic in BaseChatModel and BaseLLM (#19332)

Covered by tests in `libs/core/tests/unit_tests/language_models/chat_models/test_base.py`, `libs/core/tests/unit_tests/language_models/llms/test_base.py`, and `libs/core/tests/unit_tests/runnables/test_runnable_events.py`.
This commit is contained in: parent 40f846e65d · commit 4c2e887276
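Before this change, `astream` inspected the subclass at call time to decide between three paths: a native `_astream`, a sync `_stream` wrapped by `_as_async_iterator`, or an `ainvoke` fallback. The wrapping now lives in the default `_astream` itself, so `astream` keeps a single fallback check and the duplicated `_as_async_iterator` helper can be deleted from both modules. The observable behavior is unchanged: a model that implements only the synchronous `_stream` still gets a working `astream`. A minimal sketch of that situation (the `EchoChat` class and its whitespace chunking are invented for illustration):

```python
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult


class EchoChat(BaseChatModel):
    """Toy model: sync-only streaming, no _astream override."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Non-streaming path: echo the last message back verbatim.
        message = AIMessage(content=str(messages[-1].content))
        return ChatResult(generations=[ChatGeneration(message=message)])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # Streaming path: one chunk per whitespace-separated "token".
        for token in str(messages[-1].content).split():
            yield ChatGenerationChunk(message=AIMessageChunk(content=token + " "))


async def main() -> None:
    model = EchoChat()
    # Works even though EchoChat never defines _astream: the inherited
    # default drives _stream on an executor thread.
    async for chunk in model.astream("hello streaming world"):
        print(chunk.content, end="|")
    print()


asyncio.run(main())
```

Nothing async is defined on `EchoChat`; the async iteration comes entirely from the inherited default `_astream` shown in the diff below.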
libs/core/langchain_core/language_models/chat_models.py

```diff
@@ -8,9 +8,7 @@ from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
     AsyncIterator,
-    Callable,
     Dict,
     Iterator,
     List,
```
```diff
@@ -99,26 +97,6 @@ async def agenerate_from_stream(
     )
 
 
-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
 
```
```diff
@@ -270,28 +248,11 @@
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[BaseMessageChunk]:
-        if type(self)._astream is not BaseChatModel._astream:
-            # Then astream is implemented
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseChatModel._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        List[BaseMessage],
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[ChatGenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:  # No async or sync stream is implemented, so fall back to ainvoke
+        if (
+            type(self)._astream is BaseChatModel._astream
+            and type(self)._stream is BaseChatModel._stream
+        ):
+            # No async or sync stream is implemented, so fall back to ainvoke
             yield cast(
                 BaseMessageChunk,
                 await self.ainvoke(input, config=config, stop=stop, **kwargs),
```
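The rewritten guard relies on a detail of Python attribute lookup: fetching a method through `type(self)` returns the underlying function object, so an `is` comparison against the base-class attribute tells you whether the subclass overrode it. Only when both `_astream` and `_stream` still resolve to the `BaseChatModel` originals does `astream` fall back to `ainvoke`. A standalone sketch of the idiom (class names invented):

```python
class Base:
    def stream(self):
        raise NotImplementedError()


class WithStream(Base):
    def stream(self):
        yield "chunk"


class WithoutStream(Base):
    pass


# Class-level lookup returns the plain function (no bound-method wrapper),
# so `is` compares the identity of the defining function objects.
assert type(WithStream()).stream is not Base.stream
assert type(WithoutStream()).stream is Base.stream
```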
```diff
@@ -321,13 +282,12 @@
             batch_size=1,
         )
 
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[ChatGenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 messages,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
```
```diff
@@ -731,14 +691,32 @@
     ) -> Iterator[ChatGenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            messages,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
     def __call__(
```
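The new default `_astream` is the old `_as_async_iterator` helper inlined: the sync `_stream` is started on an executor, then every `next()` call is shipped to a worker thread as well, so the event loop never blocks between chunks. The `done` sentinel is passed as `next`'s default argument so that iterator exhaustion comes back as an ordinary return value; a `StopIteration` escaping an awaited executor future would instead surface as `RuntimeError("coroutine raised StopIteration")`. The same pattern works outside langchain; here is a self-contained sketch using plain asyncio instead of langchain's `run_in_executor` helper (all names invented):

```python
import asyncio
import time
from typing import AsyncIterator, Callable, Iterator, TypeVar

T = TypeVar("T")


async def iterate_in_thread(
    make_iterator: Callable[[], Iterator[T]],
) -> AsyncIterator[T]:
    """Drive a blocking iterator from async code, one item per thread hop."""
    loop = asyncio.get_running_loop()
    # Creating the iterator may itself block (e.g. opening a connection),
    # so that happens on a worker thread too.
    iterator = await loop.run_in_executor(None, make_iterator)
    done = object()  # unique sentinel; cannot collide with any yielded item
    while True:
        # next(iterator, done) returns the sentinel instead of raising
        # StopIteration, which must not cross the executor boundary.
        item = await loop.run_in_executor(None, next, iterator, done)
        if item is done:
            break
        yield item


def slow_numbers() -> Iterator[int]:
    for i in range(3):
        time.sleep(0.1)  # stands in for blocking I/O
        yield i


async def main() -> None:
    async for n in iterate_in_thread(slow_numbers):
        print(n)


asyncio.run(main())
```

One executor hop per chunk is simple and keeps the event loop responsive; the per-item overhead is negligible next to typical model latency.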
libs/core/langchain_core/language_models/llms.py

```diff
@@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
```
```diff
@@ -116,26 +115,6 @@ def create_base_retry_decorator(
     )
 
 
-def _as_async_iterator(sync_iterator: Callable) -> Callable:
-    """Convert a sync iterator into an async iterator."""
-
-    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
-        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
-        done = object()
-        while True:
-            item = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,  # type: ignore[call-arg, arg-type]
-            )
-            if item is done:
-                break
-            yield item  # type: ignore[misc]
-
-    return _as_sync_iterator
-
-
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
```
```diff
@@ -460,28 +439,10 @@
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream is not BaseLLM._astream:
-            # model doesn't implement streaming, so use default implementation
-            _stream_implementation = self._astream
-            using_sync_stream = False
-        elif type(self)._stream is not BaseLLM._stream:
-            # Then stream is implemented, so we can create an async iterator from it
-            # The typing is hard to type correctly with mypy here, so we cast
-            # and do a type ignore, this code is unit tested and should be fine.
-            _stream_implementation = cast(  # type: ignore
-                Callable[
-                    [
-                        str,
-                        Optional[List[str]],
-                        CallbackManagerForLLMRun,
-                        Any,
-                    ],
-                    AsyncIterator[GenerationChunk],
-                ],
-                _as_async_iterator(self._stream),
-            )
-            using_sync_stream = True
-        else:
+        if (
+            type(self)._astream is BaseLLM._astream
+            and type(self)._stream is BaseLLM._stream
+        ):
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
 
```
```diff
@@ -509,13 +470,12 @@
             run_id=config.pop("run_id", None),
             batch_size=1,
         )
-        run_manager_ = run_manager.get_sync() if using_sync_stream else run_manager
         generation: Optional[GenerationChunk] = None
         try:
-            async for chunk in _stream_implementation(
+            async for chunk in self._astream(
                 prompt,
                 stop=stop,
-                run_manager=run_manager_,  # type: ignore[arg-type]
+                run_manager=run_manager,
                 **kwargs,
             ):
                 yield chunk.text
```
```diff
@@ -571,14 +531,32 @@
     ) -> Iterator[GenerationChunk]:
         raise NotImplementedError()
 
-    def _astream(
+    async def _astream(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[GenerationChunk]:
-        raise NotImplementedError()
+        iterator = await run_in_executor(
+            None,
+            self._stream,
+            prompt,
+            stop,
+            run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
 
     def generate_prompt(
         self,
```
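The `BaseLLM` half of the commit is the same change one-for-one, with `prompt: str` in place of the message list and `GenerationChunk` in place of `ChatGenerationChunk`. A minimal sketch of a completion model that gains async streaming from the new default (the `CountLLM` class is invented for illustration; it subclasses the simpler `LLM` helper, which only requires `_call`):

```python
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class CountLLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "count-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Non-streaming path: return the whole completion at once.
        return " ".join(str(i) for i in range(3))

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Streaming path: one chunk per "token".
        for i in range(3):
            yield GenerationChunk(text=f"{i} ")


async def main() -> None:
    model = CountLLM()
    # For LLMs, astream yields plain strings; the inherited default _astream
    # drives the sync _stream on an executor thread.
    async for token in model.astream("count please"):
        print(token, end="")
    print()


asyncio.run(main())
```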