Harrison/serialize from llm and tools (#760)

This commit is contained in:
Harrison Chase 2023-01-26 23:30:39 -08:00 committed by GitHub
parent 12dc7f26cc
commit e2a7fed890
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 4 deletions

View File

@ -91,14 +91,22 @@ class ConversationalAgent(Agent):
llm: BaseLLM, llm: BaseLLM,
tools: List[Tool], tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
ai_prefix: str = "AI", ai_prefix: str = "AI",
human_prefix: str = "Human", human_prefix: str = "Human",
input_variables: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
prompt = cls.create_prompt( prompt = cls.create_prompt(
tools, ai_prefix=ai_prefix, human_prefix=human_prefix tools,
ai_prefix=ai_prefix,
human_prefix=human_prefix,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
) )
llm_chain = LLMChain( llm_chain = LLMChain(
llm=llm, llm=llm,

View File

@ -54,7 +54,9 @@ def initialize_agent(
llm, tools, callback_manager=callback_manager llm, tools, callback_manager=callback_manager
) )
elif agent_path is not None: elif agent_path is not None:
agent_obj = load_agent(agent_path, callback_manager=callback_manager) agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
)
else: else:
raise ValueError( raise ValueError(
"Somehow both `agent` and `agent_path` are None, " "Somehow both `agent` and `agent_path` are None, "

View File

@ -3,7 +3,7 @@ import json
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, List, Optional, Union
import requests import requests
import yaml import yaml
@ -13,7 +13,9 @@ from langchain.agents.conversational.base import ConversationalAgent
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool
from langchain.chains.loading import load_chain, load_chain_from_config from langchain.chains.loading import load_chain, load_chain_from_config
from langchain.llms.base import BaseLLM
AGENT_TO_CLASS = { AGENT_TO_CLASS = {
"zero-shot-react-description": ZeroShotAgent, "zero-shot-react-description": ZeroShotAgent,
@ -25,10 +27,42 @@ AGENT_TO_CLASS = {
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
def load_agent_from_config(config: dict, **kwargs: Any) -> Agent: def _load_agent_from_tools(
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
) -> Agent:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
raise ValueError(f"Loading {config_type} agent not supported")
if config_type not in AGENT_TO_CLASS:
raise ValueError(f"Loading {config_type} agent not supported")
agent_cls = AGENT_TO_CLASS[config_type]
combined_config = {**config, **kwargs}
return agent_cls.from_llm_and_tools(llm, tools, **combined_config)
def load_agent_from_config(
config: dict,
llm: Optional[BaseLLM] = None,
tools: Optional[List[Tool]] = None,
**kwargs: Any,
) -> Agent:
"""Load agent from Config Dict.""" """Load agent from Config Dict."""
if "_type" not in config: if "_type" not in config:
raise ValueError("Must specify an agent Type in config") raise ValueError("Must specify an agent Type in config")
load_from_tools = config.pop("load_from_llm_and_tools", False)
if load_from_tools:
if llm is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then LLM must be provided"
)
if tools is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then tools must be provided"
)
return _load_agent_from_tools(config, llm, tools, **kwargs)
config_type = config.pop("_type") config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS: if config_type not in AGENT_TO_CLASS:

View File

@ -7,6 +7,8 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
@ -92,6 +94,30 @@ class ZeroShotAgent(Agent):
input_variables = ["input", "agent_scratchpad"] input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables) return PromptTemplate(template=template, input_variables=input_variables)
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
input_variables: Optional[List[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
@classmethod @classmethod
def _validate_tools(cls, tools: List[Tool]) -> None: def _validate_tools(cls, tools: List[Tool]) -> None:
for tool in tools: for tool in tools: