diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py
index a0aa62083ff..0811e86a3e7 100644
--- a/libs/langchain/langchain_classic/callbacks/streaming_aiter.py
+++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter.py
@@ -39,11 +39,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
         self.done.clear()
 
     @override
-    async def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         if token is not None and token != "":
-            self.queue.put_nowait(str(token))
+            self.queue.put_nowait(token)
 
     @override
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
diff --git a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py
index b9661639ca7..744475e6628 100644
--- a/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py
+++ b/libs/langchain/langchain_classic/callbacks/streaming_aiter_final_only.py
@@ -81,11 +81,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
             self.done.set()
 
     @override
-    async def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         # Remember the last n tokens, where n = len(answer_prefix_tokens)
-        self.append_to_last_tokens(str(token))
+        self.append_to_last_tokens(token)
 
         # Check if the last n tokens match the answer_prefix_tokens list ...
         if self.check_if_answer_reached():
@@ -97,4 +95,4 @@
 
         # If yes, then put tokens from now on
         if self.answer_reached:
-            self.queue.put_nowait(str(token))
+            self.queue.put_nowait(token)
diff --git a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py
index 48af5a7b5f5..e8eee519b3e 100644
--- a/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py
+++ b/libs/langchain/langchain_classic/callbacks/streaming_stdout_final_only.py
@@ -76,12 +76,10 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
         self.answer_reached = False
 
     @override
-    def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
         # Remember the last n tokens, where n = len(answer_prefix_tokens)
-        self.append_to_last_tokens(str(token))
+        self.append_to_last_tokens(token)
 
         # Check if the last n tokens match the answer_prefix_tokens list ...
         if self.check_if_answer_reached():
@@ -94,5 +92,5 @@
 
         # ... if yes, then print tokens from now on
         if self.answer_reached:
-            sys.stdout.write(str(token))
+            sys.stdout.write(token)
             sys.stdout.flush()
diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
index d467cab89d7..88cd91c3d13 100644
--- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
@@ -173,7 +173,7 @@ async def test_callback_handlers() -> None:
         @override
         async def on_llm_new_token(
             self,
-            token: str | list[str | dict[str, Any]],
+            token: str,
             *,
             chunk: GenerationChunk | ChatGenerationChunk | None = None,
             run_id: UUID,
@@ -181,7 +181,7 @@ async def test_callback_handlers() -> None:
             parent_run_id: UUID | None = None,
             tags: list[str] | None = None,
             **kwargs: Any,
         ) -> None:
-            self.store.append(str(token))
+            self.store.append(token)
 
     infinite_cycle = cycle(
         [