langchain[patch]: Invoke callback prior to yielding token (#18282)

## PR title
langchain[patch]: Invoke callback prior to yielding

## PR message
Description: Invoke the `on_llm_new_token` callback prior to yielding the token in the `_stream` and `_astream` methods of `GenericFakeChatModel` in `langchain/tests/fake_chat_model` (see the ordering sketch below).
Issue: https://github.com/langchain-ai/langchain/issues/16913
Dependencies: None
Twitter handle: None
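
Why the ordering matters: if the generator yields the chunk before firing the callback, a consumer that stops iterating early leaves the generator suspended at the `yield`, and the callback never fires for the last token the consumer actually received. A minimal, library-free sketch of the behavior (the names here are hypothetical, not langchain's API):

```python
from typing import Callable, Iterator


def stream_tokens(
    tokens: list[str], on_new_token: Callable[[str], None]
) -> Iterator[str]:
    """Toy stand-in for _stream: fire the callback, then yield."""
    for token in tokens:
        on_new_token(token)  # callback first ...
        yield token  # ... then hand the token to the consumer


seen: list[str] = []
for tok in stream_tokens(["hello", " ", "world"], seen.append):
    if tok == " ":
        break  # consumer stops early; generator stays suspended at `yield`

print(seen)  # ["hello", " "] -- every delivered token was reported
```

With the old yield-first ordering, the same run would print `['hello']`: the callback for `' '` would only fire when the generator resumed, which it never does.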
Commit: 23722e3653 (parent cd52433ba0)
Author: William De Vena, 2024-02-28 22:15:02 +01:00 (committed by GitHub)

```diff
@@ -120,9 +120,9 @@ class GenericFakeChatModel(BaseChatModel):
 
         for token in content_chunks:
             chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(token, chunk=chunk)
+            yield chunk
 
         if message.additional_kwargs:
             for key, value in message.additional_kwargs.items():
@@ -142,12 +142,12 @@ class GenericFakeChatModel(BaseChatModel):
                                         },
                                     )
                                 )
-                                yield chunk
                                 if run_manager:
                                     run_manager.on_llm_new_token(
                                         "",
                                         chunk=chunk,  # No token for function call
                                     )
+                                yield chunk
                         else:
                             chunk = ChatGenerationChunk(
                                 message=AIMessageChunk(
@@ -155,24 +155,24 @@ class GenericFakeChatModel(BaseChatModel):
                                     additional_kwargs={"function_call": {fkey: fvalue}},
                                 )
                             )
-                            yield chunk
                             if run_manager:
                                 run_manager.on_llm_new_token(
                                     "",
                                     chunk=chunk,  # No token for function call
                                 )
+                            yield chunk
                 else:
                     chunk = ChatGenerationChunk(
                         message=AIMessageChunk(
                             content="", additional_kwargs={key: value}
                         )
                     )
-                    yield chunk
                     if run_manager:
                         run_manager.on_llm_new_token(
                             "",
                             chunk=chunk,  # No token for function call
                         )
+                    yield chunk
 
     async def _astream(
         self,
```
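
The hunks above cover `_stream`; the PR applies the same reordering in `_astream`. A self-contained async sketch of that pattern (hypothetical names, assuming an awaitable callback as on the async run manager):

```python
import asyncio
from typing import AsyncIterator, Awaitable, Callable


async def astream_tokens(
    tokens: list[str], on_new_token: Callable[[str], Awaitable[None]]
) -> AsyncIterator[str]:
    """Toy stand-in for _astream: await the callback, then yield."""
    for token in tokens:
        await on_new_token(token)  # invoke callback prior to yielding
        yield token


async def main() -> None:
    seen: list[str] = []

    async def record(token: str) -> None:
        seen.append(token)

    async for tok in astream_tokens(["a", "b", "c"], record):
        # By the time the consumer sees `tok`, the callback has run.
        assert tok in seen


asyncio.run(main())
```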