mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
add default for async (#11367)
This commit is contained in:
parent
d21dd72d64
commit
6e848b879a
@ -170,12 +170,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> BaseMessageChunk:
|
) -> BaseMessageChunk:
|
||||||
if type(self)._agenerate == BaseChatModel._agenerate:
|
|
||||||
# model doesn't implement async generation, so use default implementation
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, partial(self.invoke, input, config, stop=stop, **kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
config = config or {}
|
config = config or {}
|
||||||
llm_result = await self.agenerate_prompt(
|
llm_result = await self.agenerate_prompt(
|
||||||
[self._convert_input(input)],
|
[self._convert_input(input)],
|
||||||
@ -582,7 +576,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
raise NotImplementedError()
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
partial(
|
||||||
|
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user