commit 9adbd2fdb7 (parent ea18e7f29f)
Author: Mason Daugherty
Date: 2026-02-21 02:05:48 -05:00

4 changed files with 10 additions and 16 deletions


@@ -39,11 +39,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
         self.done.clear()
 
     @override
-    async def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         if token is not None and token != "":
-            self.queue.put_nowait(str(token))
+            self.queue.put_nowait(token)
 
     @override
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
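With `token` typed as plain `str`, the `str(...)` coercion before `put_nowait` becomes redundant, which is what this hunk removes. A minimal self-contained sketch of the queue-draining pattern the handler implements (class and demo names here are illustrative, not the library's API):

```python
import asyncio
from typing import Any


class TokenQueueHandler:
    """Illustrative stand-in for AsyncIteratorCallbackHandler's core pattern."""

    def __init__(self) -> None:
        self.queue: asyncio.Queue[str] = asyncio.Queue()
        self.done = asyncio.Event()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # token is a plain str, so it can be enqueued without str(...)
        if token is not None and token != "":
            self.queue.put_nowait(token)

    async def on_llm_end(self, **kwargs: Any) -> None:
        self.done.set()


async def demo() -> None:
    handler = TokenQueueHandler()
    for t in ("Hello", ", ", "world"):  # simulated streaming tokens
        await handler.on_llm_new_token(t)
    await handler.on_llm_end()
    while not handler.queue.empty():
        print(await handler.queue.get(), end="")


asyncio.run(demo())
```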


@@ -81,11 +81,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         self.done.set()
 
     @override
-    async def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         # Remember the last n tokens, where n = len(answer_prefix_tokens)
-        self.append_to_last_tokens(str(token))
+        self.append_to_last_tokens(token)
 
         # Check if the last n tokens match the answer_prefix_tokens list ...
         if self.check_if_answer_reached():
@@ -97,4 +95,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
 
         # If yes, then put tokens from now on
         if self.answer_reached:
-            self.queue.put_nowait(str(token))
+            self.queue.put_nowait(token)
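The `append_to_last_tokens` / `check_if_answer_reached` pair used in these hunks implements a sliding-window prefix match over the most recent tokens. A hedged sketch of that logic (the real handler also offers a token-stripping option, omitted here; the prefix below is a hypothetical example):

```python
from collections import deque

# Hypothetical prefix marking the final answer in an agent's output.
answer_prefix_tokens = ["Final", " Answer", ":"]
last_tokens: deque[str] = deque(maxlen=len(answer_prefix_tokens))


def append_to_last_tokens(token: str) -> None:
    # A deque with maxlen keeps only the most recent n tokens.
    last_tokens.append(token)


def check_if_answer_reached() -> bool:
    return list(last_tokens) == answer_prefix_tokens


for tok in ["Thought", ":", " done", "Final", " Answer", ":"]:
    append_to_last_tokens(tok)
print(check_if_answer_reached())  # True: the window now matches the prefix
```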


@@ -76,12 +76,10 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
         self.answer_reached = False
 
     @override
-    def on_llm_new_token(
-        self, token: str | list[str | dict[str, Any]], **kwargs: Any
-    ) -> None:
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         """Run on new LLM token. Only available when streaming is enabled."""
         # Remember the last n tokens, where n = len(answer_prefix_tokens)
-        self.append_to_last_tokens(str(token))
+        self.append_to_last_tokens(token)
 
         # Check if the last n tokens match the answer_prefix_tokens list ...
         if self.check_if_answer_reached():
@@ -94,5 +92,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
 
         # ... if yes, then print tokens from now on
         if self.answer_reached:
-            sys.stdout.write(str(token))
+            sys.stdout.write(token)
             sys.stdout.flush()
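For reference, this handler is typically passed via `callbacks` so that only the final answer of a streaming run is written to stdout. A usage sketch; the import path and model wiring are assumptions based on historical LangChain docs, not taken from this diff:

```python
from langchain.callbacks.streaming_stdout_final_only import (
    FinalStreamingStdOutCallbackHandler,
)
from langchain_openai import OpenAI  # assumed provider package

handler = FinalStreamingStdOutCallbackHandler(
    answer_prefix_tokens=["Final", "Answer", ":"],
)
# With streaming enabled, each token triggers on_llm_new_token; only tokens
# arriving after the answer prefix are echoed to stdout.
llm = OpenAI(streaming=True, callbacks=[handler], temperature=0)
```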


@@ -173,7 +173,7 @@ async def test_callback_handlers() -> None:
         @override
         async def on_llm_new_token(
             self,
-            token: str | list[str | dict[str, Any]],
+            token: str,
             *,
             chunk: GenerationChunk | ChatGenerationChunk | None = None,
             run_id: UUID,
@@ -181,7 +181,7 @@ async def test_callback_handlers() -> None:
             tags: list[str] | None = None,
             **kwargs: Any,
         ) -> None:
-            self.store.append(str(token))
+            self.store.append(token)
 
     infinite_cycle = cycle(
         [