mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
come back here if you still want copys
This commit is contained in:
@@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
from langchain.agents.loading import save_agent
|
||||
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@@ -28,6 +29,7 @@ from langchain.schema import (
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
from langchain.utilities.asyncio import asyncio_timeout
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -121,37 +123,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the agent.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the agent to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# If working with agent executor
|
||||
agent.agent.save(file_path="path/agent.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
agent_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(agent_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(agent_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
return save_agent(file_path, self.dict())
|
||||
|
||||
def tool_run_logging_kwargs(self) -> Dict:
|
||||
return {}
|
||||
@@ -454,7 +426,6 @@ class Agent(BaseSingleActionAgent):
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
|
||||
"""Validate that appropriate tools are passed in."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -539,7 +510,7 @@ class AgentExecutor(Chain):
|
||||
"""Consists of an agent using tools."""
|
||||
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
|
||||
tools: Sequence[BaseTool]
|
||||
tools: Sequence[BaseStructuredTool]
|
||||
return_intermediate_steps: bool = False
|
||||
max_iterations: Optional[int] = 15
|
||||
max_execution_time: Optional[float] = None
|
||||
@@ -549,7 +520,7 @@ class AgentExecutor(Chain):
|
||||
def from_agent_and_tools(
|
||||
cls,
|
||||
agent: Union[BaseSingleActionAgent, BaseMultiActionAgent],
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
@@ -617,7 +588,7 @@ class AgentExecutor(Chain):
|
||||
else:
|
||||
return self.agent.return_values
|
||||
|
||||
def lookup_tool(self, name: str) -> BaseTool:
|
||||
def lookup_tool(self, name: str) -> BaseStructuredTool:
|
||||
"""Lookup tool by name."""
|
||||
return {tool.name: tool for tool in self.tools}[name]
|
||||
|
||||
@@ -659,7 +630,7 @@ class AgentExecutor(Chain):
|
||||
|
||||
def _take_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
name_to_tool_map: Dict[str, BaseStructuredTool],
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
@@ -711,7 +682,7 @@ class AgentExecutor(Chain):
|
||||
|
||||
async def _atake_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
name_to_tool_map: Dict[str, BaseStructuredTool],
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
|
||||
@@ -110,3 +110,34 @@ def _load_agent_from_file(
|
||||
raise ValueError("File type must be json or yaml")
|
||||
# Load the agent from the config now.
|
||||
return load_agent_from_config(config, **kwargs)
|
||||
|
||||
|
||||
def save_agent(file_path: Union[Path, str], agent_dict: dict) -> None:
|
||||
"""Save the agent.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the agent to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# If working with agent executor
|
||||
agent.agent.save(file_path="path/agent.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(agent_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(agent_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@@ -8,21 +8,28 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import (
|
||||
StructuredAgentAction,
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseLanguageModel,
|
||||
BaseMessage,
|
||||
BaseOutputParser,
|
||||
StructuredAgentAction,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSingleActionAgent(BaseModel):
|
||||
"""Base Agent class."""
|
||||
class BaseStructuredSingleActionAgent(BaseModel):
|
||||
"""Base Agent for Structured Tool usage class."""
|
||||
|
||||
@property
|
||||
def return_values(self) -> List[str]:
|
||||
@@ -91,7 +98,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
@@ -143,3 +150,218 @@ class BaseSingleActionAgent(BaseModel):
|
||||
|
||||
def tool_run_logging_kwargs(self) -> Dict:
|
||||
return {}
|
||||
|
||||
|
||||
class StructuredAgentOutputParser(BaseOutputParser):
|
||||
# TODO: Use a different agent action type
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""Parse text into agent action/finish."""
|
||||
|
||||
|
||||
# FOOBAR
|
||||
class Agent(BaseStructuredSingleActionAgent):
|
||||
"""Class responsible for calling the language model and deciding the action.
|
||||
|
||||
This is driven by an LLMChain. The prompt in the LLMChain MUST include
|
||||
a variable called "agent_scratchpad" where the agent can put its
|
||||
intermediary work.
|
||||
"""
|
||||
|
||||
llm_chain: LLMChain
|
||||
output_parser: StructuredAgentOutputParser
|
||||
allowed_tools: Optional[List[str]] = None
|
||||
|
||||
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||
return self.allowed_tools
|
||||
|
||||
@property
|
||||
def return_values(self) -> List[str]:
|
||||
return ["output"]
|
||||
|
||||
def _fix_text(self, text: str) -> str:
|
||||
"""Fix the text."""
|
||||
raise ValueError("fix_text not implemented for this agent.")
|
||||
|
||||
@property
|
||||
def _stop(self) -> List[str]:
|
||||
return [
|
||||
f"\n{self.observation_prefix.rstrip()}",
|
||||
f"\n\t{self.observation_prefix.rstrip()}",
|
||||
]
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> Union[str, List[BaseMessage]]:
|
||||
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
return thoughts
|
||||
|
||||
def plan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = self.llm_chain.predict(**full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
async def aplan(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
full_output = await self.llm_chain.apredict(**full_inputs)
|
||||
return self.output_parser.parse(full_output)
|
||||
|
||||
def get_full_inputs(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Create the full inputs for the LLMChain from intermediate steps."""
|
||||
thoughts = self._construct_scratchpad(intermediate_steps)
|
||||
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
|
||||
full_inputs = {**kwargs, **new_inputs}
|
||||
return full_inputs
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Return the input keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return list(set(self.llm_chain.input_keys) - {"agent_scratchpad"})
|
||||
|
||||
@root_validator()
|
||||
def validate_prompt(cls, values: Dict) -> Dict:
|
||||
"""Validate that prompt matches format."""
|
||||
prompt = values["llm_chain"].prompt
|
||||
if "agent_scratchpad" not in prompt.input_variables:
|
||||
logger.warning(
|
||||
"`agent_scratchpad` should be a variable in prompt.input_variables."
|
||||
" Did not find it, so adding it at the end."
|
||||
)
|
||||
prompt.input_variables.append("agent_scratchpad")
|
||||
if isinstance(prompt, PromptTemplate):
|
||||
prompt.template += "\n{agent_scratchpad}"
|
||||
elif isinstance(prompt, FewShotPromptTemplate):
|
||||
prompt.suffix += "\n{agent_scratchpad}"
|
||||
else:
|
||||
raise ValueError(f"Got unexpected prompt type {type(prompt)}")
|
||||
return values
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def observation_prefix(self) -> str:
|
||||
"""Prefix to append the observation with."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def llm_prefix(self) -> str:
|
||||
"""Prefix to append the LLM call with."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def create_prompt(cls, tools: Sequence[BaseStructuredTool]) -> BasePromptTemplate:
|
||||
"""Create a prompt for this class."""
|
||||
|
||||
@classmethod
|
||||
def _validate_tools(cls, tools: Sequence[BaseStructuredTool]) -> None:
|
||||
"""Validate that appropriate tools are passed in."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> StructuredAgentOutputParser:
|
||||
"""Get default output parser for this class."""
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseStructuredTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[StructuredAgentOutputParser] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=cls.create_prompt(tools),
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser or cls._get_default_output_parser()
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
"""Return response when agent has been stopped due to max iterations."""
|
||||
if early_stopping_method == "force":
|
||||
# `force` just returns a constant string
|
||||
return AgentFinish(
|
||||
{"output": "Agent stopped due to iteration limit or time limit."}, ""
|
||||
)
|
||||
elif early_stopping_method == "generate":
|
||||
# Generate does one final forward pass
|
||||
thoughts = ""
|
||||
for action, observation in intermediate_steps:
|
||||
thoughts += action.log
|
||||
thoughts += (
|
||||
f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
)
|
||||
# Adding to the previous steps, we now tell the LLM to make a final pred
|
||||
thoughts += (
|
||||
"\n\nI now need to return a final answer based on the previous steps:"
|
||||
)
|
||||
new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
|
||||
full_inputs = {**kwargs, **new_inputs}
|
||||
full_output = self.llm_chain.predict(**full_inputs)
|
||||
# We try to extract a final answer
|
||||
parsed_output = self.output_parser.parse(full_output)
|
||||
if isinstance(parsed_output, AgentFinish):
|
||||
# If we can extract, we send the correct stuff
|
||||
return parsed_output
|
||||
else:
|
||||
# If we can extract, but the tool is not the final tool,
|
||||
# we just return the full output
|
||||
return AgentFinish({"output": full_output}, full_output)
|
||||
else:
|
||||
raise ValueError(
|
||||
"early_stopping_method should be one of `force` or `generate`, "
|
||||
f"got {early_stopping_method}"
|
||||
)
|
||||
|
||||
def tool_run_logging_kwargs(self) -> Dict:
|
||||
return {
|
||||
"llm_prefix": self.llm_prefix,
|
||||
"observation_prefix": self.observation_prefix,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user