mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 22:56:05 +00:00
cr
This commit is contained in:
@@ -39,11 +39,9 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||||
self.done.clear()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if token is not None and token != "":
|
||||
self.queue.put_nowait(str(token))
|
||||
self.queue.put_nowait(token)
|
||||
|
||||
@override
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
|
||||
@@ -81,11 +81,9 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
self.done.set()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
# Remember the last n tokens, where n = len(answer_prefix_tokens)
|
||||
self.append_to_last_tokens(str(token))
|
||||
self.append_to_last_tokens(token)
|
||||
|
||||
# Check if the last n tokens match the answer_prefix_tokens list ...
|
||||
if self.check_if_answer_reached():
|
||||
@@ -97,4 +95,4 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
|
||||
# If yes, then put tokens from now on
|
||||
if self.answer_reached:
|
||||
self.queue.put_nowait(str(token))
|
||||
self.queue.put_nowait(token)
|
||||
|
||||
@@ -76,12 +76,10 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
self.answer_reached = False
|
||||
|
||||
@override
|
||||
def on_llm_new_token(
|
||||
self, token: str | list[str | dict[str, Any]], **kwargs: Any
|
||||
) -> None:
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
# Remember the last n tokens, where n = len(answer_prefix_tokens)
|
||||
self.append_to_last_tokens(str(token))
|
||||
self.append_to_last_tokens(token)
|
||||
|
||||
# Check if the last n tokens match the answer_prefix_tokens list ...
|
||||
if self.check_if_answer_reached():
|
||||
@@ -94,5 +92,5 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
|
||||
# ... if yes, then print tokens from now on
|
||||
if self.answer_reached:
|
||||
sys.stdout.write(str(token))
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
@@ -173,7 +173,7 @@ async def test_callback_handlers() -> None:
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str | list[str | dict[str, Any]],
|
||||
token: str,
|
||||
*,
|
||||
chunk: GenerationChunk | ChatGenerationChunk | None = None,
|
||||
run_id: UUID,
|
||||
@@ -181,7 +181,7 @@ async def test_callback_handlers() -> None:
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.store.append(str(token))
|
||||
self.store.append(token)
|
||||
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user