# Tool Input Schema

By default, tools infer their argument schema by inspecting the function signature. For stricter requirements, you can specify a custom input schema, along with custom validation logic.

In [1]:
from typing import Dict, Any

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.llms import OpenAI
from langchain.tools.python.tool import PythonREPLTool
from pydantic import BaseModel
from pydantic import Field, ValidationError, root_validator

In [2]:
# Temperature 0 asks the model for its most likely completion, so the agent
# traces below are as repeatable as the API allows.
llm = OpenAI(
    temperature=0,
)

In [3]:
class ToolInputModel(BaseModel):
    """Input schema for the Python REPL tool, with a toy validation rule.

    Attributes:
        query: Python source that the agent wants the REPL to execute.
    """

    query: str = Field(...)

    # skip_on_failure=True: run only after field validation succeeded, so
    # ``values["query"]`` is guaranteed to exist. Without it, pydantic v1 still
    # calls this validator after a field error (e.g. a missing ``query``) and the
    # lookup raises a raw KeyError; pydantic v2 rejects a post-root validator
    # that does not set this flag.
    @root_validator(skip_on_failure=True)
    def validate_query(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any query containing the literal substring ``import os``.

        Args:
            values: Field values after field-level validation.

        Returns:
            ``values`` unchanged when the query is allowed.

        Raises:
            ValueError: if ``"import os"`` appears in the query; pydantic
                reports it as a ``ValidationError`` for ``ToolInputModel``.
        """
        # Note: this is NOT a safe REPL! This is used for instructive purposes only.
        # A plain substring test is trivially bypassed, e.g. by
        # ``from os import listdir`` or ``__import__("os")``.
        if "import os" in values["query"]:
            raise ValueError("'import os' not permitted in this python REPL.")
        return values
    
# Replace the inferred argument schema with ToolInputModel, so tool inputs are
# validated before the REPL executes them.
tool = PythonREPLTool(
    args_schema=ToolInputModel,
)

In [4]:
# A zero-shot ReAct agent whose only tool is the schema-validated Python REPL;
# verbose=True prints each Action / Action Input step of the chain.
agent = initialize_agent(
    [tool],
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

In [5]:
# Succeeds: none of the code the agent sends to the REPL contains "import os",
# so ToolInputModel's validator lets every input through.
agent.run("Run a python function that adds 2 and 2")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m I need to define a function that adds two numbers
Action: Python REPL
Action Input: def add_two_numbers(a, b):
    return a + b[0m[36;1m[1;3m[0m[32;1m[1;3m I need to call the function
Action: Python REPL
Action Input: print(add_two_numbers(2, 2))[0m[36;1m[1;3m4
[0m[32;1m[1;3m I now know the final answer
Final Answer: 4[0m

[1m> Finished chain.[0m


'4'

In [6]:
# Expected to fail: the agent's code contains "import os", which the validator on
# ToolInputModel rejects. The error is caught and printed so that the
# demonstration is still visible without aborting "Restart & Run All".
try:
    agent.run("Run a python function that imports os and lists the dir")
except ValidationError as err:
    print(f"{type(err).__name__}: {err}")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m I need to import os and then list the dir
Action: Python REPL
Action Input: import os; print(os.listdir())[0m

ValidationError: 1 validation error for ToolInputModel
__root__
  'import os' not permitted in this python REPL. (type=value_error)