mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-16 12:32:06 +00:00
langchain[patch]: Improve stream_log with AgentExecutor and Runnable Agent (#15792)
This PR fixes an issue where AgentExecutor with RunnableAgent does not allow users to see individual llm tokens if streaming=True is not set explicitly on the underlying chat model. The majority of this PR is testing code: 1. Create a test chat model that makes it easier to test streaming and supports AIMessages that include function invocation information. 2. Tests for the chat model 3. Tests for RunnableAgent (previously untested) 4. Tests for openai agent (previously untested)
This commit is contained in:
parent
85a4594ed7
commit
feb41c5e28
@ -371,7 +371,7 @@ class RunnableAgent(BaseSingleActionAgent):
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
"""Based on past history and current inputs, decide what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
@ -383,8 +383,19 @@ class RunnableAgent(BaseSingleActionAgent):
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
# Use streaming to make sure that the underlying LLM is invoked in a streaming
|
||||
# fashion to make it possible to get access to the individual LLM tokens
|
||||
# when using stream_log with the Agent Executor.
|
||||
# Because the response from the plan is not a generator, we need to
|
||||
# accumulate the output into final output and return that.
|
||||
final_output: Any = None
|
||||
for chunk in self.runnable.stream(inputs, config={"callbacks": callbacks}):
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
final_output += chunk
|
||||
|
||||
return final_output
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
@ -395,20 +406,32 @@ class RunnableAgent(BaseSingleActionAgent):
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
]:
|
||||
"""Given input, decided what to do.
|
||||
"""Based on past history and current inputs, decide what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
callbacks: Callbacks to run.
|
||||
**kwargs: User inputs.
|
||||
**kwargs: User inputs
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
final_output: Any = None
|
||||
# Use streaming to make sure that the underlying LLM is invoked in a streaming
|
||||
# fashion to make it possible to get access to the individual LLM tokens
|
||||
# when using stream_log with the Agent Executor.
|
||||
# Because the response from the plan is not a generator, we need to
|
||||
# accumulate the output into final output and return that.
|
||||
async for chunk in self.runnable.astream(
|
||||
inputs, config={"callbacks": callbacks}
|
||||
):
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
final_output += chunk
|
||||
return final_output
|
||||
|
||||
|
||||
class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
@ -447,7 +470,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
List[AgentAction],
|
||||
AgentFinish,
|
||||
]:
|
||||
"""Given input, decided what to do.
|
||||
"""Based on past history and current inputs, decide what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
@ -459,8 +482,19 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
# Use streaming to make sure that the underlying LLM is invoked in a streaming
|
||||
# fashion to make it possible to get access to the individual LLM tokens
|
||||
# when using stream_log with the Agent Executor.
|
||||
# Because the response from the plan is not a generator, we need to
|
||||
# accumulate the output into final output and return that.
|
||||
final_output: Any = None
|
||||
for chunk in self.runnable.stream(inputs, config={"callbacks": callbacks}):
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
final_output += chunk
|
||||
|
||||
return final_output
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
@ -471,7 +505,7 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
List[AgentAction],
|
||||
AgentFinish,
|
||||
]:
|
||||
"""Given input, decided what to do.
|
||||
"""Based on past history and current inputs, decide what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
@ -483,8 +517,21 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
|
||||
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
|
||||
return output
|
||||
# Use streaming to make sure that the underlying LLM is invoked in a streaming
|
||||
# fashion to make it possible to get access to the individual LLM tokens
|
||||
# when using stream_log with the Agent Executor.
|
||||
# Because the response from the plan is not a generator, we need to
|
||||
# accumulate the output into final output and return that.
|
||||
final_output: Any = None
|
||||
async for chunk in self.runnable.astream(
|
||||
inputs, config={"callbacks": callbacks}
|
||||
):
|
||||
if final_output is None:
|
||||
final_output = chunk
|
||||
else:
|
||||
final_output += chunk
|
||||
|
||||
return final_output
|
||||
|
||||
|
||||
@deprecated(
|
||||
|
@ -1,16 +1,37 @@
|
||||
"""Unit tests for agents."""
|
||||
import json
|
||||
from itertools import cycle
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentStep
|
||||
from langchain_core.agents import (
|
||||
AgentAction,
|
||||
AgentActionMessageLog,
|
||||
AgentFinish,
|
||||
AgentStep,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
)
|
||||
from langchain_core.prompts import MessagesPlaceholder
|
||||
from langchain_core.runnables.utils import add
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_core.tracers import RunLog, RunLogPatch
|
||||
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
AgentType,
|
||||
create_openai_functions_agent,
|
||||
initialize_agent,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.tools import tool
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
@ -414,3 +435,356 @@ def test_agent_invalid_tool() -> None:
|
||||
|
||||
resp = agent("when was langchain made")
|
||||
resp["intermediate_steps"][0][1] == "Foo is not a valid tool, try one of [Search]."
|
||||
|
||||
|
||||
async def test_runnable_agent() -> None:
|
||||
"""Simple test to verify that an agent built with LCEL works."""
|
||||
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello world!")])
|
||||
# When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[("system", "You are Cat Agent 007"), ("human", "{question}")]
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
|
||||
|
||||
agent = template | model | fake_parse
|
||||
executor = AgentExecutor(agent=agent, tools=[])
|
||||
|
||||
# Invoke
|
||||
result = executor.invoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# ainvoke
|
||||
result = await executor.ainvoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# Batch
|
||||
result = executor.batch( # type: ignore[assignment]
|
||||
[{"question": "hello"}, {"question": "hello"}]
|
||||
)
|
||||
assert result == [
|
||||
{"foo": "meow", "question": "hello"},
|
||||
{"foo": "meow", "question": "hello"},
|
||||
]
|
||||
|
||||
# abatch
|
||||
result = await executor.abatch( # type: ignore[assignment]
|
||||
[{"question": "hello"}, {"question": "hello"}]
|
||||
)
|
||||
assert result == [
|
||||
{"foo": "meow", "question": "hello"},
|
||||
{"foo": "meow", "question": "hello"},
|
||||
]
|
||||
|
||||
# Stream
|
||||
results = list(executor.stream({"question": "hello"}))
|
||||
assert results == [
|
||||
{"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]}
|
||||
]
|
||||
|
||||
# astream
|
||||
results = [r async for r in executor.astream({"question": "hello"})]
|
||||
assert results == [
|
||||
{
|
||||
"foo": "meow",
|
||||
"messages": [
|
||||
AIMessage(content="hard-coded-message"),
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# stream log
|
||||
results: List[RunLogPatch] = [ # type: ignore[no-redef]
|
||||
r async for r in executor.astream_log({"question": "hello"})
|
||||
]
|
||||
# # Let's stream just the llm tokens.
|
||||
messages = []
|
||||
for log_record in results:
|
||||
for op in log_record.ops: # type: ignore[attr-defined]
|
||||
if op["op"] == "add" and isinstance(op["value"], AIMessageChunk):
|
||||
messages.append(op["value"])
|
||||
|
||||
assert messages != []
|
||||
|
||||
# Aggregate state
|
||||
run_log = None
|
||||
|
||||
for result in results:
|
||||
if run_log is None:
|
||||
run_log = result
|
||||
else:
|
||||
# `+` is defined for RunLogPatch
|
||||
run_log = run_log + result # type: ignore[union-attr]
|
||||
|
||||
assert isinstance(run_log, RunLog)
|
||||
|
||||
assert run_log.state["final_output"] == {
|
||||
"foo": "meow",
|
||||
"messages": [AIMessage(content="hard-coded-message")],
|
||||
}
|
||||
|
||||
|
||||
async def test_runnable_agent_with_function_calls() -> None:
|
||||
"""Test agent with intermediate agent actions."""
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle(
|
||||
[AIMessage(content="looking for pet..."), AIMessage(content="Found Pet")]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[("system", "You are Cat Agent 007"), ("human", "{question}")]
|
||||
)
|
||||
|
||||
parser_responses = cycle(
|
||||
[
|
||||
AgentAction(
|
||||
tool="find_pet",
|
||||
tool_input={
|
||||
"pet": "cat",
|
||||
},
|
||||
log="find_pet()",
|
||||
),
|
||||
AgentFinish(
|
||||
return_values={"foo": "meow"},
|
||||
log="hard-coded-message",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast(Union[AgentFinish, AgentAction], next(parser_responses))
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
"""Find the given pet."""
|
||||
if pet != "cat":
|
||||
raise ValueError("Only cats allowed")
|
||||
return "Spying from under the bed."
|
||||
|
||||
agent = template | model | fake_parse
|
||||
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||
|
||||
# Invoke
|
||||
result = executor.invoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# ainvoke
|
||||
result = await executor.ainvoke({"question": "hello"})
|
||||
assert result == {"foo": "meow", "question": "hello"}
|
||||
|
||||
# astream
|
||||
results = [r async for r in executor.astream({"question": "hello"})]
|
||||
assert results == [
|
||||
{
|
||||
"actions": [
|
||||
AgentAction(
|
||||
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||
)
|
||||
],
|
||||
"messages": [AIMessage(content="find_pet()")],
|
||||
},
|
||||
{
|
||||
"messages": [HumanMessage(content="Spying from under the bed.")],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=AgentAction(
|
||||
tool="find_pet", tool_input={"pet": "cat"}, log="find_pet()"
|
||||
),
|
||||
observation="Spying from under the bed.",
|
||||
)
|
||||
],
|
||||
},
|
||||
{"foo": "meow", "messages": [AIMessage(content="hard-coded-message")]},
|
||||
]
|
||||
|
||||
# astream log
|
||||
|
||||
messages = []
|
||||
async for patch in executor.astream_log({"question": "hello"}):
|
||||
for op in patch.ops:
|
||||
if op["op"] != "add":
|
||||
continue
|
||||
|
||||
value = op["value"]
|
||||
|
||||
if not isinstance(value, AIMessageChunk):
|
||||
continue
|
||||
|
||||
if value.content == "": # Then it's a function invocation message
|
||||
continue
|
||||
|
||||
messages.append(value.content)
|
||||
|
||||
assert messages == ["looking", " ", "for", " ", "pet...", "Found", " ", "Pet"]
|
||||
|
||||
|
||||
def _make_func_invocation(name: str, **kwargs: Any) -> AIMessage:
|
||||
"""Create an AIMessage that represents a function invocation.
|
||||
|
||||
Args:
|
||||
name: Name of the function to invoke.
|
||||
kwargs: Keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
AIMessage that represents a request to invoke a function.
|
||||
"""
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": name,
|
||||
"arguments": json.dumps(kwargs),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_openai_agent_with_streaming() -> None:
|
||||
"""Test openai agent with streaming."""
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
_make_func_invocation("find_pet", pet="cat"),
|
||||
AIMessage(content="The cat is spying from under the bed."),
|
||||
]
|
||||
)
|
||||
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
@tool
|
||||
def find_pet(pet: str) -> str:
|
||||
"""Find the given pet."""
|
||||
if pet != "cat":
|
||||
raise ValueError("Only cats allowed")
|
||||
return "Spying from under the bed."
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are a helpful AI bot. Your name is kitty power meow."),
|
||||
("human", "{question}"),
|
||||
MessagesPlaceholder(
|
||||
variable_name="agent_scratchpad",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# type error due to base tool type below -- would need to be adjusted on tool
|
||||
# decorator.
|
||||
agent = create_openai_functions_agent(
|
||||
model,
|
||||
[find_pet], # type: ignore[list-item]
|
||||
template,
|
||||
)
|
||||
executor = AgentExecutor(agent=agent, tools=[find_pet])
|
||||
|
||||
# Invoke
|
||||
result = executor.invoke({"question": "hello"})
|
||||
assert result == {
|
||||
"output": "The cat is spying from under the bed.",
|
||||
"question": "hello",
|
||||
}
|
||||
|
||||
# astream
|
||||
chunks = [chunk async for chunk in executor.astream({"question": "hello"})]
|
||||
assert chunks == [
|
||||
{
|
||||
"actions": [
|
||||
AgentActionMessageLog(
|
||||
tool="find_pet",
|
||||
tool_input={"pet": "cat"},
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
"messages": [
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
FunctionMessage(content="Spying from under the bed.", name="find_pet")
|
||||
],
|
||||
"steps": [
|
||||
AgentStep(
|
||||
action=AgentActionMessageLog(
|
||||
tool="find_pet",
|
||||
tool_input={"pet": "cat"},
|
||||
log="\nInvoking: `find_pet` with `{'pet': 'cat'}`\n\n\n",
|
||||
message_log=[
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "find_pet",
|
||||
"arguments": '{"pet": "cat"}',
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
),
|
||||
observation="Spying from under the bed.",
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"messages": [AIMessage(content="The cat is spying from under the bed.")],
|
||||
"output": "The cat is spying from under the bed.",
|
||||
},
|
||||
]
|
||||
#
|
||||
# # astream_log
|
||||
log_patches = [
|
||||
log_patch async for log_patch in executor.astream_log({"question": "hello"})
|
||||
]
|
||||
|
||||
messages = []
|
||||
|
||||
for log_patch in log_patches:
|
||||
for op in log_patch.ops:
|
||||
if op["op"] == "add" and isinstance(op["value"], AIMessageChunk):
|
||||
value = op["value"]
|
||||
if value.content: # Filter out function call messages
|
||||
messages.append(value.content)
|
||||
|
||||
assert messages == [
|
||||
"The",
|
||||
" ",
|
||||
"cat",
|
||||
" ",
|
||||
"is",
|
||||
" ",
|
||||
"spying",
|
||||
" ",
|
||||
"from",
|
||||
" ",
|
||||
"under",
|
||||
" ",
|
||||
"the",
|
||||
" ",
|
||||
"bed.",
|
||||
]
|
||||
|
@ -1,9 +1,15 @@
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
|
||||
|
||||
from langchain_core.language_models.chat_models import SimpleChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -42,3 +48,151 @@ class FakeChatModel(SimpleChatModel):
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
return {"key": "fake"}
|
||||
|
||||
|
||||
class GenericFakeChatModel(BaseChatModel):
|
||||
"""A generic fake chat model that can be used to test the chat model interface.
|
||||
|
||||
* Chat model should be usable in both sync and async tests
|
||||
* Invokes on_llm_new_token to allow for testing of callback related code for new
|
||||
tokens.
|
||||
* Includes logic to break messages into message chunk to facilitate testing of
|
||||
streaming.
|
||||
"""
|
||||
|
||||
messages: Iterator[AIMessage]
|
||||
"""Get an iterator over messages.
|
||||
|
||||
This can be expanded to accept other types like Callables / dicts / strings
|
||||
to make the interface more generic if needed.
|
||||
|
||||
Note: if you want to pass a list, you can use `iter` to convert it to an iterator.
|
||||
|
||||
Please note that streaming is not implemented yet. We should try to implement it
|
||||
in the future by delegating to invoke and then breaking the resulting output
|
||||
into message chunks.
|
||||
"""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
message = next(self.messages)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Stream the output of the model."""
|
||||
chat_result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
if not isinstance(chat_result, ChatResult):
|
||||
raise ValueError(
|
||||
f"Expected generate to return a ChatResult, "
|
||||
f"but got {type(chat_result)} instead."
|
||||
)
|
||||
|
||||
message = chat_result.generations[0].message
|
||||
|
||||
if not isinstance(message, AIMessage):
|
||||
raise ValueError(
|
||||
f"Expected invoke to return an AIMessage, "
|
||||
f"but got {type(message)} instead."
|
||||
)
|
||||
|
||||
content = message.content
|
||||
|
||||
if content:
|
||||
# Use a regular expression to split on whitespace with a capture group
|
||||
# so that we can preserve the whitespace in the output.
|
||||
assert isinstance(content, str)
|
||||
content_chunks = cast(List[str], re.split(r"(\s)", content))
|
||||
|
||||
for token in content_chunks:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token, chunk=chunk)
|
||||
|
||||
if message.additional_kwargs:
|
||||
for key, value in message.additional_kwargs.items():
|
||||
# We should further break down the additional kwargs into chunks
|
||||
# Special case for function call
|
||||
if key == "function_call":
|
||||
for fkey, fvalue in value.items():
|
||||
if isinstance(fvalue, str):
|
||||
# Break function call by `,`
|
||||
fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
|
||||
for fvalue_chunk in fvalue_chunks:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {fkey: fvalue_chunk}
|
||||
},
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
"",
|
||||
chunk=chunk, # No token for function call
|
||||
)
|
||||
else:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {fkey: fvalue}},
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
"",
|
||||
chunk=chunk, # No token for function call
|
||||
)
|
||||
else:
|
||||
chunk = ChatGenerationChunk(
|
||||
message=AIMessageChunk(
|
||||
content="", additional_kwargs={key: value}
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
"",
|
||||
chunk=chunk, # No token for function call
|
||||
)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""Stream the output of the model."""
|
||||
result = await run_in_executor(
|
||||
None,
|
||||
self._stream,
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager.get_sync() if run_manager else None,
|
||||
**kwargs,
|
||||
)
|
||||
for chunk in result:
|
||||
yield chunk
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "generic-fake-chat-model"
|
||||
|
185
libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
Normal file
185
libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""Tests for verifying that testing utility code works as expected."""
|
||||
from itertools import cycle
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
|
||||
|
||||
def test_generic_fake_chat_model_invoke() -> None:
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
response = model.invoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
response = model.invoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_ainvoke() -> None:
|
||||
# Will alternate between responding with hello and goodbye
|
||||
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
response = await model.ainvoke("kitty")
|
||||
assert response == AIMessage(content="goodbye")
|
||||
response = await model.ainvoke("meow")
|
||||
assert response == AIMessage(content="hello")
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_stream() -> None:
|
||||
"""Test streaming."""
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
AIMessage(content="hello goodbye"),
|
||||
]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
chunks = [chunk for chunk in model.stream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
# Test streaming of additional kwargs.
|
||||
# Relying on insertion order of the additional kwargs dict
|
||||
message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
assert chunks == [
|
||||
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
|
||||
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
|
||||
]
|
||||
|
||||
message = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "move_file",
|
||||
"arguments": '{\n "source_path": "foo",\n "'
|
||||
'destination_path": "bar"\n}',
|
||||
}
|
||||
},
|
||||
)
|
||||
model = GenericFakeChatModel(messages=cycle([message]))
|
||||
chunks = [chunk async for chunk in model.astream("meow")]
|
||||
|
||||
assert chunks == [
|
||||
AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"name": "move_file"}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '{\n "source_path": "foo"'}
|
||||
},
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="", additional_kwargs={"function_call": {"arguments": ","}}
|
||||
),
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
accumulate_chunks = None
|
||||
for chunk in chunks:
|
||||
if accumulate_chunks is None:
|
||||
accumulate_chunks = chunk
|
||||
else:
|
||||
accumulate_chunks += chunk
|
||||
|
||||
assert accumulate_chunks == AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "move_file",
|
||||
"arguments": '{\n "source_path": "foo",\n "'
|
||||
'destination_path": "bar"\n}',
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def test_generic_fake_chat_model_astream_log() -> None:
|
||||
"""Test streaming."""
|
||||
infinite_cycle = cycle([AIMessage(content="hello goodbye")])
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
log_patches = [
|
||||
log_patch async for log_patch in model.astream_log("meow", diff=False)
|
||||
]
|
||||
final = log_patches[-1]
|
||||
assert final.state["streamed_output"] == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
]
|
||||
|
||||
|
||||
async def test_callback_handlers() -> None:
|
||||
"""Verify that model is implemented correctly with handlers working."""
|
||||
|
||||
class MyCustomAsyncHandler(AsyncCallbackHandler):
|
||||
def __init__(self, store: List[str]) -> None:
|
||||
self.store = store
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# Do nothing
|
||||
# Required to implement since this is an abstract method
|
||||
pass
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.store.append(token)
|
||||
|
||||
infinite_cycle = cycle(
|
||||
[
|
||||
AIMessage(content="hello goodbye"),
|
||||
]
|
||||
)
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
tokens: List[str] = []
|
||||
# New model
|
||||
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
|
||||
assert results == [
|
||||
AIMessageChunk(content="hello"),
|
||||
AIMessageChunk(content=" "),
|
||||
AIMessageChunk(content="goodbye"),
|
||||
]
|
||||
assert tokens == ["hello", " ", "goodbye"]
|
Loading…
Reference in New Issue
Block a user