mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 06:26:12 +00:00
core: fix batch race condition in FakeListChatModel (#26924)
fixed #26273
This commit is contained in:
@@ -13,6 +13,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
class FakeMessagesListChatModel(BaseChatModel):
|
||||
@@ -128,6 +129,33 @@ class FakeListChatModel(SimpleChatModel):
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {"responses": self.responses}
|
||||
|
||||
# manually override batch to preserve batch ordering with no concurrency
|
||||
def batch(
|
||||
self,
|
||||
inputs: list[Any],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseMessage]:
|
||||
if isinstance(config, list):
|
||||
return [self.invoke(m, c, **kwargs) for m, c in zip(inputs, config)]
|
||||
return [self.invoke(m, config, **kwargs) for m in inputs]
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: list[Any],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[BaseMessage]:
|
||||
if isinstance(config, list):
|
||||
# do Not use an async iterator here because need explicit ordering
|
||||
return [await self.ainvoke(m, c, **kwargs) for m, c in zip(inputs, config)]
|
||||
# do Not use an async iterator here because need explicit ordering
|
||||
return [await self.ainvoke(m, config, **kwargs) for m in inputs]
|
||||
|
||||
|
||||
class FakeChatModel(SimpleChatModel):
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
|
Reference in New Issue
Block a user