mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
infra: fake model invoke callback prior to yielding token (#18286)
## PR title core[patch]: Invoke callback prior to yielding ## PR message Description: Invoke on_llm_new_token callback prior to yielding token in _stream and _astream methods. Issue: https://github.com/langchain-ai/langchain/issues/16913 Dependencies: None Twitter handle: None
This commit is contained in:
parent
31b4e78174
commit
42341bc787
@ -225,9 +225,9 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
|
|
||||||
for token in content_chunks:
|
for token in content_chunks:
|
||||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
||||||
yield chunk
|
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(token, chunk=chunk)
|
run_manager.on_llm_new_token(token, chunk=chunk)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
if message.additional_kwargs:
|
if message.additional_kwargs:
|
||||||
for key, value in message.additional_kwargs.items():
|
for key, value in message.additional_kwargs.items():
|
||||||
@ -247,12 +247,12 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield chunk
|
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(
|
run_manager.on_llm_new_token(
|
||||||
"",
|
"",
|
||||||
chunk=chunk, # No token for function call
|
chunk=chunk, # No token for function call
|
||||||
)
|
)
|
||||||
|
yield chunk
|
||||||
else:
|
else:
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
message=AIMessageChunk(
|
message=AIMessageChunk(
|
||||||
@ -260,24 +260,24 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
additional_kwargs={"function_call": {fkey: fvalue}},
|
additional_kwargs={"function_call": {fkey: fvalue}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield chunk
|
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(
|
run_manager.on_llm_new_token(
|
||||||
"",
|
"",
|
||||||
chunk=chunk, # No token for function call
|
chunk=chunk, # No token for function call
|
||||||
)
|
)
|
||||||
|
yield chunk
|
||||||
else:
|
else:
|
||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
message=AIMessageChunk(
|
message=AIMessageChunk(
|
||||||
content="", additional_kwargs={key: value}
|
content="", additional_kwargs={key: value}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield chunk
|
|
||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(
|
run_manager.on_llm_new_token(
|
||||||
"",
|
"",
|
||||||
chunk=chunk, # No token for function call
|
chunk=chunk, # No token for function call
|
||||||
)
|
)
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def _astream(
|
async def _astream(
|
||||||
self,
|
self,
|
||||||
|
@ -398,22 +398,14 @@ async def test_event_stream_with_simple_chain() -> None:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello")},
|
"data": {"chunk": AIMessageChunk(content="hello")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chat_model_stream",
|
||||||
"metadata": {"foo": "bar"},
|
"metadata": {"a": "b", "foo": "bar"},
|
||||||
"name": "my_chain",
|
"name": "my_model",
|
||||||
"run_id": "",
|
"run_id": "",
|
||||||
"tags": ["my_chain"],
|
"tags": ["my_chain", "my_model", "seq:step:2"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="hello")},
|
"data": {"chunk": AIMessageChunk(content="hello")},
|
||||||
"event": "on_chat_model_stream",
|
|
||||||
"metadata": {"a": "b", "foo": "bar"},
|
|
||||||
"name": "my_model",
|
|
||||||
"run_id": "",
|
|
||||||
"tags": ["my_chain", "my_model", "seq:step:2"],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {"chunk": AIMessageChunk(content=" ")},
|
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {"foo": "bar"},
|
"metadata": {"foo": "bar"},
|
||||||
"name": "my_chain",
|
"name": "my_chain",
|
||||||
@ -429,7 +421,7 @@ async def test_event_stream_with_simple_chain() -> None:
|
|||||||
"tags": ["my_chain", "my_model", "seq:step:2"],
|
"tags": ["my_chain", "my_model", "seq:step:2"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {"chunk": AIMessageChunk(content="world!")},
|
"data": {"chunk": AIMessageChunk(content=" ")},
|
||||||
"event": "on_chain_stream",
|
"event": "on_chain_stream",
|
||||||
"metadata": {"foo": "bar"},
|
"metadata": {"foo": "bar"},
|
||||||
"name": "my_chain",
|
"name": "my_chain",
|
||||||
@ -444,6 +436,14 @@ async def test_event_stream_with_simple_chain() -> None:
|
|||||||
"run_id": "",
|
"run_id": "",
|
||||||
"tags": ["my_chain", "my_model", "seq:step:2"],
|
"tags": ["my_chain", "my_model", "seq:step:2"],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"data": {"chunk": AIMessageChunk(content="world!")},
|
||||||
|
"event": "on_chain_stream",
|
||||||
|
"metadata": {"foo": "bar"},
|
||||||
|
"name": "my_chain",
|
||||||
|
"run_id": "",
|
||||||
|
"tags": ["my_chain"],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"input": {
|
"input": {
|
||||||
|
Loading…
Reference in New Issue
Block a user