mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 11:12:47 +00:00
cr
This commit is contained in:
parent
ac208f85c8
commit
f646c94bc1
@ -1,5 +1,5 @@
|
|||||||
"""Routing chains."""
|
"""Routing chains."""
|
||||||
from langchain.agents.agent import Agent
|
from langchain.agents.agent import AgentWithTools
|
||||||
from langchain.agents.loading import initialize_agent
|
from langchain.agents.loading import initialize_agent
|
||||||
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
|
||||||
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
|
||||||
@ -10,7 +10,7 @@ __all__ = [
|
|||||||
"MRKLChain",
|
"MRKLChain",
|
||||||
"SelfAskWithSearchChain",
|
"SelfAskWithSearchChain",
|
||||||
"ReActChain",
|
"ReActChain",
|
||||||
"Agent",
|
"AgentWithTools",
|
||||||
"Tool",
|
"Tool",
|
||||||
"initialize_agent",
|
"initialize_agent",
|
||||||
"ZeroShotAgent",
|
"ZeroShotAgent",
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
"""Chain that takes in an input and produces an action and action input."""
|
"""Chain that takes in an input and produces an action and action input."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.agents.input import ChainedInput
|
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -17,7 +16,7 @@ from langchain.prompts.base import BasePromptTemplate
|
|||||||
from langchain.schema import AgentAction, AgentFinish
|
from langchain.schema import AgentAction, AgentFinish
|
||||||
|
|
||||||
|
|
||||||
class Planner(BaseModel):
|
class Agent(BaseModel):
|
||||||
"""Class responsible for calling the language model and deciding the action.
|
"""Class responsible for calling the language model and deciding the action.
|
||||||
|
|
||||||
This is driven by an LLMChain. The prompt in the LLMChain MUST include
|
This is driven by an LLMChain. The prompt in the LLMChain MUST include
|
||||||
@ -72,6 +71,7 @@ class Planner(BaseModel):
|
|||||||
return AgentAction(tool, tool_input, full_output)
|
return AgentAction(tool, tool_input, full_output)
|
||||||
|
|
||||||
def prepare_for_new_call(self) -> None:
|
def prepare_for_new_call(self) -> None:
|
||||||
|
"""Prepare the agent for new call, if needed."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -107,8 +107,8 @@ class Planner(BaseModel):
|
|||||||
def llm_prefix(self) -> str:
|
def llm_prefix(self) -> str:
|
||||||
"""Prefix to append the LLM call with."""
|
"""Prefix to append the LLM call with."""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
|
||||||
"""Create a prompt for this class."""
|
"""Create a prompt for this class."""
|
||||||
|
|
||||||
@ -118,16 +118,17 @@ class Planner(BaseModel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Planner:
|
def from_llm_and_tools(cls, llm: LLM, tools: List[Tool]) -> Agent:
|
||||||
"""Construct an agent from an LLM and tools."""
|
"""Construct an agent from an LLM and tools."""
|
||||||
cls._validate_tools(tools)
|
cls._validate_tools(tools)
|
||||||
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
||||||
return cls(llm_chain=llm_chain)
|
return cls(llm_chain=llm_chain)
|
||||||
|
|
||||||
|
|
||||||
class Agent(Chain, BaseModel):
|
class AgentWithTools(Chain, BaseModel):
|
||||||
|
"""Consists of an agent using tools."""
|
||||||
|
|
||||||
planner: Planner
|
agent: Agent
|
||||||
tools: List[Tool]
|
tools: List[Tool]
|
||||||
return_intermediate_steps: bool = False
|
return_intermediate_steps: bool = False
|
||||||
|
|
||||||
@ -137,7 +138,7 @@ class Agent(Chain, BaseModel):
|
|||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
return self.planner.input_keys
|
return self.agent.input_keys
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
@ -146,26 +147,25 @@ class Agent(Chain, BaseModel):
|
|||||||
:meta private:
|
:meta private:
|
||||||
"""
|
"""
|
||||||
if self.return_intermediate_steps:
|
if self.return_intermediate_steps:
|
||||||
return self.planner.return_values + ["intermediate_steps"]
|
return self.agent.return_values + ["intermediate_steps"]
|
||||||
else:
|
else:
|
||||||
return self.planner.return_values
|
return self.agent.return_values
|
||||||
|
|
||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||||
"""Run text through and get agent response."""
|
"""Run text through and get agent response."""
|
||||||
# Do any preparation necessary when receiving a new input.
|
# Do any preparation necessary when receiving a new input.
|
||||||
self.planner.prepare_for_new_call()
|
self.agent.prepare_for_new_call()
|
||||||
# Construct a mapping of tool name to tool for easy lookup
|
# Construct a mapping of tool name to tool for easy lookup
|
||||||
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
|
name_to_tool_map = {tool.name: tool.func for tool in self.tools}
|
||||||
# We construct a mapping from each tool to a color, used for logging.
|
# We construct a mapping from each tool to a color, used for logging.
|
||||||
color_mapping = get_color_mapping(
|
color_mapping = get_color_mapping(
|
||||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||||
)
|
)
|
||||||
planner_inputs = inputs.copy()
|
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||||
# We now enter the agent loop (until it returns something).
|
# We now enter the agent loop (until it returns something).
|
||||||
while True:
|
while True:
|
||||||
# Call the LLM to see what to do.
|
# Call the LLM to see what to do.
|
||||||
output = self.planner.plan(intermediate_steps, **planner_inputs)
|
output = self.agent.plan(intermediate_steps, **inputs)
|
||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
"""Input manager for agents."""
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import langchain
|
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
|
|
||||||
class ChainedInput:
|
|
||||||
"""Class for working with input that is the result of chains."""
|
|
||||||
|
|
||||||
def __init__(self, text: str, verbose: bool = False):
|
|
||||||
"""Initialize with verbose flag and initial text."""
|
|
||||||
self._verbose = verbose
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_start(text)
|
|
||||||
self._input = text
|
|
||||||
self._intermediate_actions: List[AgentAction] = []
|
|
||||||
self._intermediate_observations: List[str] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def intermediate_steps(self) -> List:
|
|
||||||
"""Return intermediate steps the agent took."""
|
|
||||||
steps = []
|
|
||||||
for i, action in enumerate(self._intermediate_actions):
|
|
||||||
step = {
|
|
||||||
"log": action.log,
|
|
||||||
"tool": action.tool,
|
|
||||||
"tool_input": action.tool_input,
|
|
||||||
"observation": self._intermediate_observations[i],
|
|
||||||
}
|
|
||||||
steps.append(step)
|
|
||||||
return steps
|
|
||||||
|
|
||||||
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
|
|
||||||
"""Add text to input, print if in verbose mode."""
|
|
||||||
|
|
||||||
self._input += action.log
|
|
||||||
self._intermediate_actions.append(action)
|
|
||||||
|
|
||||||
def add_observation(
|
|
||||||
self,
|
|
||||||
observation: str,
|
|
||||||
observation_prefix: str,
|
|
||||||
llm_prefix: str,
|
|
||||||
color: Optional[str],
|
|
||||||
) -> None:
|
|
||||||
"""Add observation to input, print if in verbose mode."""
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_observation(
|
|
||||||
observation,
|
|
||||||
color=color,
|
|
||||||
observation_prefix=observation_prefix,
|
|
||||||
llm_prefix=llm_prefix,
|
|
||||||
)
|
|
||||||
self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
|
||||||
self._intermediate_observations.append(observation)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input(self) -> str:
|
|
||||||
"""Return the accumulated input."""
|
|
||||||
return self._input
|
|
@ -1,17 +1,17 @@
|
|||||||
"""Load agent."""
|
"""Load agent."""
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, Planner
|
from langchain.agents.agent import AgentWithTools
|
||||||
from langchain.agents.mrkl.base import ZeroShotPlanner
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
from langchain.agents.react.base import ReActDocstorePlanner
|
from langchain.agents.react.base import ReActDocstoreAgent
|
||||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchPlanner
|
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
AGENT_TO_CLASS = {
|
AGENT_TO_CLASS = {
|
||||||
"zero-shot-react-description": ZeroShotPlanner,
|
"zero-shot-react-description": ZeroShotAgent,
|
||||||
"react-docstore": ReActDocstorePlanner,
|
"react-docstore": ReActDocstoreAgent,
|
||||||
"self-ask-with-search": SelfAskWithSearchPlanner,
|
"self-ask-with-search": SelfAskWithSearchAgent,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ def initialize_agent(
|
|||||||
llm: LLM,
|
llm: LLM,
|
||||||
agent: str = "zero-shot-react-description",
|
agent: str = "zero-shot-react-description",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Agent:
|
) -> AgentWithTools:
|
||||||
"""Load agent given tools and LLM.
|
"""Load agent given tools and LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -39,5 +39,5 @@ def initialize_agent(
|
|||||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||||
)
|
)
|
||||||
agent_cls = AGENT_TO_CLASS[agent]
|
agent_cls = AGENT_TO_CLASS[agent]
|
||||||
planner = agent_cls.from_llm_and_tools(llm, tools, **kwargs)
|
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
|
||||||
return Agent(planner=planner, tools=tools)
|
return AgentWithTools(agent=agent_obj, tools=tools, **kwargs)
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
from typing import Any, Callable, List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, Planner
|
from langchain.agents.agent import Agent, AgentWithTools
|
||||||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
@ -47,7 +47,7 @@ def get_action_and_input(llm_output: str) -> Tuple[str, str]:
|
|||||||
return action, action_input.strip(" ").strip('"')
|
return action, action_input.strip(" ").strip('"')
|
||||||
|
|
||||||
|
|
||||||
class ZeroShotPlanner(Planner):
|
class ZeroShotAgent(Agent):
|
||||||
"""Agent for the MRKL chain."""
|
"""Agent for the MRKL chain."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -101,10 +101,7 @@ class ZeroShotPlanner(Planner):
|
|||||||
return get_action_and_input(text)
|
return get_action_and_input(text)
|
||||||
|
|
||||||
|
|
||||||
ZeroShotAgent = ZeroShotPlanner
|
class MRKLChain(AgentWithTools):
|
||||||
|
|
||||||
|
|
||||||
class MRKLChain(Agent):
|
|
||||||
"""Chain that implements the MRKL system.
|
"""Chain that implements the MRKL system.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -119,7 +116,9 @@ class MRKLChain(Agent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
|
def from_chains(
|
||||||
|
cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any
|
||||||
|
) -> AgentWithTools:
|
||||||
"""User friendly way to initialize the MRKL chain.
|
"""User friendly way to initialize the MRKL chain.
|
||||||
|
|
||||||
This is intended to be an easy way to get up and running with the
|
This is intended to be an easy way to get up and running with the
|
||||||
@ -159,5 +158,5 @@ class MRKLChain(Agent):
|
|||||||
Tool(name=c.action_name, func=c.action, description=c.action_description)
|
Tool(name=c.action_name, func=c.action, description=c.action_description)
|
||||||
for c in chains
|
for c in chains
|
||||||
]
|
]
|
||||||
planner = ZeroShotPlanner.from_llm_and_tools(llm, tools)
|
agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
|
||||||
return cls(planner=planner, tools=tools, **kwargs)
|
return cls(agent=agent, tools=tools, **kwargs)
|
||||||
|
@ -1,21 +1,20 @@
|
|||||||
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
"""Chain that implements the ReAct paper from https://arxiv.org/pdf/2210.03629.pdf."""
|
||||||
import re
|
import re
|
||||||
from typing import Any, ClassVar, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, Planner
|
from langchain.agents.agent import Agent, AgentWithTools
|
||||||
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
|
||||||
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
|
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.docstore.base import Docstore
|
from langchain.docstore.base import Docstore
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
|
||||||
class ReActDocstorePlanner(Planner, BaseModel):
|
class ReActDocstoreAgent(Agent, BaseModel):
|
||||||
"""Agent for the ReAct chin."""
|
"""Agent for the ReAct chin."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -75,9 +74,6 @@ class ReActDocstorePlanner(Planner, BaseModel):
|
|||||||
return f"Thought {self.i}:"
|
return f"Thought {self.i}:"
|
||||||
|
|
||||||
|
|
||||||
ReActDocstoreAgent = ReActDocstorePlanner
|
|
||||||
|
|
||||||
|
|
||||||
class DocstoreExplorer:
|
class DocstoreExplorer:
|
||||||
"""Class to assist with exploration of a document store."""
|
"""Class to assist with exploration of a document store."""
|
||||||
|
|
||||||
@ -103,7 +99,7 @@ class DocstoreExplorer:
|
|||||||
return self.document.lookup(term)
|
return self.document.lookup(term)
|
||||||
|
|
||||||
|
|
||||||
class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
|
class ReActTextWorldAgent(ReActDocstoreAgent, BaseModel):
|
||||||
"""Agent for the ReAct TextWorld chain."""
|
"""Agent for the ReAct TextWorld chain."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -120,10 +116,7 @@ class ReActTextWorldPlanner(ReActDocstorePlanner, BaseModel):
|
|||||||
raise ValueError(f"Tool name should be Play, got {tool_names}")
|
raise ValueError(f"Tool name should be Play, got {tool_names}")
|
||||||
|
|
||||||
|
|
||||||
ReActTextWorldAgent = ReActTextWorldPlanner
|
class ReActChain(AgentWithTools):
|
||||||
|
|
||||||
|
|
||||||
class ReActChain(Agent):
|
|
||||||
"""Chain that implements the ReAct paper.
|
"""Chain that implements the ReAct paper.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -140,5 +133,5 @@ class ReActChain(Agent):
|
|||||||
Tool(name="Search", func=docstore_explorer.search),
|
Tool(name="Search", func=docstore_explorer.search),
|
||||||
Tool(name="Lookup", func=docstore_explorer.lookup),
|
Tool(name="Lookup", func=docstore_explorer.lookup),
|
||||||
]
|
]
|
||||||
planner = ReActDocstorePlanner.from_llm_and_tools(llm, tools)
|
agent = ReActDocstoreAgent.from_llm_and_tools(llm, tools)
|
||||||
super().__init__(planner=planner, tools=tools, **kwargs)
|
super().__init__(agent=agent, tools=tools, **kwargs)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Chain that does self ask with search."""
|
"""Chain that does self ask with search."""
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, Planner
|
from langchain.agents.agent import Agent, AgentWithTools
|
||||||
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
||||||
from langchain.agents.tools import Tool
|
from langchain.agents.tools import Tool
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
@ -9,7 +9,7 @@ from langchain.prompts.base import BasePromptTemplate
|
|||||||
from langchain.serpapi import SerpAPIWrapper
|
from langchain.serpapi import SerpAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
class SelfAskWithSearchPlanner(Planner):
|
class SelfAskWithSearchAgent(Agent):
|
||||||
"""Agent for the self-ask-with-search paper."""
|
"""Agent for the self-ask-with-search paper."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -63,10 +63,7 @@ class SelfAskWithSearchPlanner(Planner):
|
|||||||
return "Are follow up questions needed here:"
|
return "Are follow up questions needed here:"
|
||||||
|
|
||||||
|
|
||||||
SelfAskWithSearchAgent = SelfAskWithSearchPlanner
|
class SelfAskWithSearchChain(AgentWithTools):
|
||||||
|
|
||||||
|
|
||||||
class SelfAskWithSearchChain(Agent):
|
|
||||||
"""Chain that does self ask with search.
|
"""Chain that does self ask with search.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -80,5 +77,5 @@ class SelfAskWithSearchChain(Agent):
|
|||||||
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
|
||||||
"""Initialize with just an LLM and a search chain."""
|
"""Initialize with just an LLM and a search chain."""
|
||||||
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
|
||||||
planner = SelfAskWithSearchPlanner.from_llm_and_tools(llm, [search_tool])
|
agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool])
|
||||||
super().__init__(planner=planner, tools=[search_tool], **kwargs)
|
super().__init__(agent=agent, tools=[search_tool], **kwargs)
|
||||||
|
@ -1,75 +0,0 @@
|
|||||||
"""Test input manipulating logic."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
from langchain.agents.input import ChainedInput
|
|
||||||
from langchain.input import get_color_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_not_verbose() -> None:
|
|
||||||
"""Test chained input logic."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_verbose() -> None:
|
|
||||||
"""Test chained input logic, making sure verbose doesn't mess it up."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo", verbose=True)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "foo"
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n1bar\n2"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("baz", "3", "4", "blue")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2\n3baz\n4"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping() -> None:
|
|
||||||
"""Test getting of color mapping."""
|
|
||||||
# Test on few inputs.
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
expected_output = {"foo": "blue", "bar": "yellow"}
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
# Test on a lot of inputs.
|
|
||||||
items = [f"foo-{i}" for i in range(20)]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
assert len(output) == 20
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping_excluded_colors() -> None:
|
|
||||||
"""Test getting of color mapping with excluded colors."""
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items, excluded_colors=["blue"])
|
|
||||||
expected_output = {"foo": "yellow", "bar": "pink"}
|
|
||||||
assert output == expected_output
|
|
Loading…
Reference in New Issue
Block a user