core: fix batch race condition in FakeListChatModel (#26924)

fixed #26273
This commit is contained in:
Erick Friis
2024-10-03 16:14:31 -07:00
committed by GitHub
parent 87fc5ce688
commit ab4dab9a0c
3 changed files with 48 additions and 9 deletions

View File

@@ -13,6 +13,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
class FakeMessagesListChatModel(BaseChatModel):
@@ -128,6 +129,33 @@ class FakeListChatModel(SimpleChatModel):
def _identifying_params(self) -> dict[str, Any]:
return {"responses": self.responses}
# Manually override batch to run strictly sequentially (no executor):
# the fake model advances a shared response cursor, so concurrent
# invocations would race and scramble response ordering.
def batch(
    self,
    inputs: list[Any],
    config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    **kwargs: Any,
) -> list[BaseMessage]:
    """Invoke the model once per input, in order, with no concurrency.

    Args:
        inputs: Inputs to invoke the model with, one call per item.
        config: A single config broadcast to every input, or a list of
            configs paired positionally with ``inputs``.
        return_exceptions: If True, an exception raised by an invocation is
            returned in its slot instead of propagating (matches the
            ``Runnable.batch`` contract).
        **kwargs: Forwarded to each ``invoke`` call.

    Returns:
        One result (or exception, when ``return_exceptions``) per input,
        in input order.
    """
    # Broadcast a single config so both cases share one sequential loop.
    configs = config if isinstance(config, list) else [config] * len(inputs)
    results: list[Any] = []
    for item, cfg in zip(inputs, configs):
        if return_exceptions:
            try:
                results.append(self.invoke(item, cfg, **kwargs))
            except Exception as e:
                # Per Runnable.batch semantics, surface the error in-place.
                results.append(e)
        else:
            results.append(self.invoke(item, cfg, **kwargs))
    return results
async def abatch(
    self,
    inputs: list[Any],
    config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    **kwargs: Any,
) -> list[BaseMessage]:
    """Async-invoke the model once per input, in order, with no concurrency.

    Deliberately awaits each call before starting the next (no ``gather``):
    the fake model's shared response cursor needs deterministic ordering.

    Args:
        inputs: Inputs to invoke the model with, one call per item.
        config: A single config broadcast to every input, or a list of
            configs paired positionally with ``inputs``.
        return_exceptions: If True, an exception raised by an invocation is
            returned in its slot instead of propagating (matches the
            ``Runnable.abatch`` contract).
        **kwargs: Forwarded to each ``ainvoke`` call.

    Returns:
        One result (or exception, when ``return_exceptions``) per input,
        in input order.
    """
    # Broadcast a single config so both cases share one sequential loop.
    configs = config if isinstance(config, list) else [config] * len(inputs)
    results: list[Any] = []
    for item, cfg in zip(inputs, configs):
        if return_exceptions:
            try:
                results.append(await self.ainvoke(item, cfg, **kwargs))
            except Exception as e:
                # Per Runnable.abatch semantics, surface the error in-place.
                results.append(e)
        else:
            results.append(await self.ainvoke(item, cfg, **kwargs))
    return results
class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""