diff --git a/libs/core/tests/benchmarks/test_async_callbacks.py b/libs/core/tests/benchmarks/test_async_callbacks.py index 0a3c9127dc4..d07224c375e 100644 --- a/libs/core/tests/benchmarks/test_async_callbacks.py +++ b/libs/core/tests/benchmarks/test_async_callbacks.py @@ -1,40 +1,44 @@ -# ruff: noqa: ARG002 import asyncio from itertools import cycle -from typing import Any +from typing import Any, Optional, Union +from uuid import UUID import pytest -from pytest_benchmark.fixture import BenchmarkFixture # type: ignore +from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-untyped] +from typing_extensions import override from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.language_models import GenericFakeChatModel -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatGenerationChunk, GenerationChunk class MyCustomAsyncHandler(AsyncCallbackHandler): + @override async def on_chat_model_start( self, - serialized: Any, - messages: Any, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], *, - run_id: Any, - parent_run_id: Any = None, - tags: Any = None, - metadata: Any = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: # Do nothing # Required to implement since this is an abstract method pass + @override async def on_llm_new_token( self, token: str, *, - chunk: Any = None, - run_id: Any, - parent_run_id: Any = None, - tags: Any = None, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, **kwargs: Any, ) -> None: await asyncio.sleep(0)