Adding missing types in some pydantic models (#9355)

* Adding missing types in some pydantic models -- this change is
required for making the code work with pydantic v2.
This commit is contained in:
Eugene Yurtsev 2023-08-16 23:10:34 -04:00 committed by GitHub
parent 1c089cadd7
commit 4c2de2a7f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 56 additions and 54 deletions

View File

@ -615,9 +615,9 @@ class Agent(BaseSingleActionAgent):
class ExceptionTool(BaseTool): class ExceptionTool(BaseTool):
"""Tool that just returns the query.""" """Tool that just returns the query."""
name = "_Exception" name: str = "_Exception"
"""Name of the tool.""" """Name of the tool."""
description = "Exception tool" description: str = "Exception tool"
"""Description of the tool.""" """Description of the tool."""
def _run( def _run(

View File

@ -182,7 +182,7 @@ class EmbeddingsClusteringFilter(BaseDocumentTransformer, BaseModel):
"""By default results are re-ordered "grouping" them by cluster, if sorted is true """By default results are re-ordered "grouping" them by cluster, if sorted is true
result will be ordered by the original position from the retriever""" result will be ordered by the original position from the retriever"""
remove_duplicates = False remove_duplicates: bool = False
""" By default duplicated results are skipped and replaced by the next closest """ By default duplicated results are skipped and replaced by the next closest
vector in the cluster. If remove_duplicates is true no replacement will be done: vector in the cluster. If remove_duplicates is true no replacement will be done:
This could dramatically reduce results when there is a lot of overlap between This could dramatically reduce results when there is a lot of overlap between

View File

@ -37,9 +37,9 @@ class RocksetChatMessageHistory(BaseChatMessageHistory):
# These values are configured for the typical # These values are configured for the typical
# free VI. Read more about VIs here: # free VI. Read more about VIs here:
# https://rockset.com/docs/instances # https://rockset.com/docs/instances
SLEEP_INTERVAL_MS = 5 SLEEP_INTERVAL_MS: int = 5
ADD_TIMEOUT_MS = 5000 ADD_TIMEOUT_MS: int = 5000
CREATE_TIMEOUT_MS = 20000 CREATE_TIMEOUT_MS: int = 20000
def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None: def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
"""Sleeps until meth() evaluates to true. Passes kwargs into """Sleeps until meth() evaluates to true. Passes kwargs into

View File

@ -13,8 +13,8 @@ class MotorheadMemory(BaseChatMemory):
"""Chat message memory backed by Motorhead service.""" """Chat message memory backed by Motorhead service."""
url: str = MANAGED_URL url: str = MANAGED_URL
timeout = 3000 timeout: int = 3000
memory_key = "history" memory_key: str = "history"
session_id: str session_id: str
context: Optional[str] = None context: Optional[str] = None

View File

@ -18,8 +18,8 @@ class GooglePlacesSchema(BaseModel):
class GooglePlacesTool(BaseTool): class GooglePlacesTool(BaseTool):
"""Tool that queries the Google places API.""" """Tool that queries the Google places API."""
name = "google_places" name: str = "google_places"
description = ( description: str = (
"A wrapper around Google Places. " "A wrapper around Google Places. "
"Useful for when you need to validate or " "Useful for when you need to validate or "
"discover addressed from ambiguous text. " "discover addressed from ambiguous text. "

View File

@ -84,8 +84,8 @@ class JsonSpec(BaseModel):
class JsonListKeysTool(BaseTool): class JsonListKeysTool(BaseTool):
"""Tool for listing keys in a JSON spec.""" """Tool for listing keys in a JSON spec."""
name = "json_spec_list_keys" name: str = "json_spec_list_keys"
description = """ description: str = """
Can be used to list all keys at a given path. Can be used to list all keys at a given path.
Before calling this you should be SURE that the path to this exists. Before calling this you should be SURE that the path to this exists.
The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]). The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).
@ -110,8 +110,8 @@ class JsonListKeysTool(BaseTool):
class JsonGetValueTool(BaseTool): class JsonGetValueTool(BaseTool):
"""Tool for getting a value in a JSON spec.""" """Tool for getting a value in a JSON spec."""
name = "json_spec_get_value" name: str = "json_spec_get_value"
description = """ description: str = """
Can be used to see value in string format at a given path. Can be used to see value in string format at a given path.
Before calling this you should be SURE that the path to this exists. Before calling this you should be SURE that the path to this exists.
The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]). The input is a text representation of the path to the dict in Python syntax (e.g. data["key1"][0]["key2"]).

View File

@ -58,8 +58,8 @@ class NUASchema(BaseModel):
class NucliaUnderstandingAPI(BaseTool): class NucliaUnderstandingAPI(BaseTool):
"""Tool to process files with the Nuclia Understanding API.""" """Tool to process files with the Nuclia Understanding API."""
name = "nuclia_understanding_api" name: str = "nuclia_understanding_api"
description = ( description: str = (
"A wrapper around Nuclia Understanding API endpoints. " "A wrapper around Nuclia Understanding API endpoints. "
"Useful for when you need to extract text from any kind of files. " "Useful for when you need to extract text from any kind of files. "
) )

View File

@ -32,8 +32,8 @@ class BaseRequestsTool(BaseModel):
class RequestsGetTool(BaseRequestsTool, BaseTool): class RequestsGetTool(BaseRequestsTool, BaseTool):
"""Tool for making a GET request to an API endpoint.""" """Tool for making a GET request to an API endpoint."""
name = "requests_get" name: str = "requests_get"
description = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request." description: str = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request."
def _run( def _run(
self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
@ -53,8 +53,8 @@ class RequestsGetTool(BaseRequestsTool, BaseTool):
class RequestsPostTool(BaseRequestsTool, BaseTool): class RequestsPostTool(BaseRequestsTool, BaseTool):
"""Tool for making a POST request to an API endpoint.""" """Tool for making a POST request to an API endpoint."""
name = "requests_post" name: str = "requests_post"
description = """Use this when you want to POST to a website. description: str = """Use this when you want to POST to a website.
Input should be a json string with two keys: "url" and "data". Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to POST to the url. key-value pairs you want to POST to the url.
@ -90,8 +90,8 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
class RequestsPatchTool(BaseRequestsTool, BaseTool): class RequestsPatchTool(BaseRequestsTool, BaseTool):
"""Tool for making a PATCH request to an API endpoint.""" """Tool for making a PATCH request to an API endpoint."""
name = "requests_patch" name: str = "requests_patch"
description = """Use this when you want to PATCH to a website. description: str = """Use this when you want to PATCH to a website.
Input should be a json string with two keys: "url" and "data". Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to PATCH to the url. key-value pairs you want to PATCH to the url.
@ -127,8 +127,8 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
class RequestsPutTool(BaseRequestsTool, BaseTool): class RequestsPutTool(BaseRequestsTool, BaseTool):
"""Tool for making a PUT request to an API endpoint.""" """Tool for making a PUT request to an API endpoint."""
name = "requests_put" name: str = "requests_put"
description = """Use this when you want to PUT to a website. description: str = """Use this when you want to PUT to a website.
Input should be a json string with two keys: "url" and "data". Input should be a json string with two keys: "url" and "data".
The value of "url" should be a string, and the value of "data" should be a dictionary of The value of "url" should be a string, and the value of "data" should be a dictionary of
key-value pairs you want to PUT to the url. key-value pairs you want to PUT to the url.
@ -164,8 +164,8 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
class RequestsDeleteTool(BaseRequestsTool, BaseTool): class RequestsDeleteTool(BaseRequestsTool, BaseTool):
"""Tool for making a DELETE request to an API endpoint.""" """Tool for making a DELETE request to an API endpoint."""
name = "requests_delete" name: str = "requests_delete"
description = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request." description: str = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request."
def _run( def _run(
self, self,

View File

@ -17,8 +17,8 @@ class SceneXplainInput(BaseModel):
class SceneXplainTool(BaseTool): class SceneXplainTool(BaseTool):
"""Tool that explains images.""" """Tool that explains images."""
name = "image_explainer" name: str = "image_explainer"
description = ( description: str = (
"An Image Captioning Tool: Use this tool to generate a detailed caption " "An Image Captioning Tool: Use this tool to generate a detailed caption "
"for an image. The input can be an image file of any format, and " "for an image. The input can be an image file of any format, and "
"the output will be a text description that covers every detail of the image." "the output will be a text description that covers every detail of the image."

View File

@ -21,9 +21,9 @@ class SleepInput(BaseModel):
class SleepTool(BaseTool): class SleepTool(BaseTool):
"""Tool that adds the capability to sleep.""" """Tool that adds the capability to sleep."""
name = "sleep" name: str = "sleep"
args_schema: Type[BaseModel] = SleepInput args_schema: Type[BaseModel] = SleepInput
description = "Make agent sleep for a specified number of seconds." description: str = "Make agent sleep for a specified number of seconds."
def _run( def _run(
self, self,

View File

@ -33,8 +33,8 @@ class BaseSparkSQLTool(BaseModel):
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool): class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for querying a Spark SQL.""" """Tool for querying a Spark SQL."""
name = "query_sql_db" name: str = "query_sql_db"
description = """ description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL. Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL.
If the query is not correct, an error message will be returned. If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again. If an error is returned, rewrite the query, check the query, and try again.
@ -52,8 +52,8 @@ class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool): class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting metadata about a Spark SQL.""" """Tool for getting metadata about a Spark SQL."""
name = "schema_sql_db" name: str = "schema_sql_db"
description = """ description: str = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Be sure that the tables actually exist by calling list_tables_sql_db first! Be sure that the tables actually exist by calling list_tables_sql_db first!
@ -72,8 +72,8 @@ class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool): class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
"""Tool for getting tables names.""" """Tool for getting tables names."""
name = "list_tables_sql_db" name: str = "list_tables_sql_db"
description = "Input is an empty string, output is a comma separated list of tables in the Spark SQL." description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."
def _run( def _run(
self, self,
@ -91,8 +91,8 @@ class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLanguageModel llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False) llm_chain: LLMChain = Field(init=False)
name = "query_checker_sql_db" name: str = "query_checker_sql_db"
description = """ description: str = """
Use this tool to double check if your query is correct before executing it. Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db! Always use this tool before executing a query with query_sql_db!
""" """

View File

@ -33,8 +33,8 @@ class BaseSQLDatabaseTool(BaseModel):
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database.""" """Tool for querying a SQL database."""
name = "sql_db_query" name: str = "sql_db_query"
description = """ description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the database. Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned. If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again. If an error is returned, rewrite the query, check the query, and try again.
@ -52,8 +52,8 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database.""" """Tool for getting metadata about a SQL database."""
name = "sql_db_schema" name: str = "sql_db_schema"
description = """ description: str = """
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
Example Input: "table1, table2, table3" Example Input: "table1, table2, table3"
@ -71,8 +71,8 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names.""" """Tool for getting tables names."""
name = "sql_db_list_tables" name: str = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database." description: str = "Input is an empty string, output is a comma separated list of tables in the database."
def _run( def _run(
self, self,
@ -90,8 +90,8 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
template: str = QUERY_CHECKER template: str = QUERY_CHECKER
llm: BaseLanguageModel llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False) llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker" name: str = "sql_db_query_checker"
description = """ description: str = """
Use this tool to double check if your query is correct before executing it. Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with query_sql_db! Always use this tool before executing a query with query_sql_db!
""" """

View File

@ -50,7 +50,7 @@ class ArxivAPIWrapper(BaseModel):
arxiv_search: Any #: :meta private: arxiv_search: Any #: :meta private:
arxiv_exceptions: Any # :meta private: arxiv_exceptions: Any # :meta private:
top_k_results: int = 3 top_k_results: int = 3
ARXIV_MAX_QUERY_LENGTH = 300 ARXIV_MAX_QUERY_LENGTH: int = 300
load_max_docs: int = 100 load_max_docs: int = 100
load_all_available_meta: bool = False load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000 doc_content_chars_max: Optional[int] = 4000

View File

@ -14,7 +14,7 @@ class BraveSearchWrapper(BaseModel):
"""The API key to use for the Brave search engine.""" """The API key to use for the Brave search engine."""
search_kwargs: dict = Field(default_factory=dict) search_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the search request.""" """Additional keyword arguments to pass to the search request."""
base_url = "https://api.search.brave.com/res/v1/web/search" base_url: str = "https://api.search.brave.com/res/v1/web/search"
"""The base URL for the Brave search engine.""" """The base URL for the Brave search engine."""
def run(self, query: str) -> str: def run(self, query: str) -> str:

View File

@ -36,14 +36,16 @@ class PubMedAPIWrapper(BaseModel):
parse: Any #: :meta private: parse: Any #: :meta private:
base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?" base_url_esearch: str = (
base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?" "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
max_retry = 5 )
sleep_time = 0.2 base_url_efetch: str = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
max_retry: int = 5
sleep_time: float = 0.2
# Default values for the parameters # Default values for the parameters
top_k_results: int = 3 top_k_results: int = 3
MAX_QUERY_LENGTH = 300 MAX_QUERY_LENGTH: int = 300
doc_content_chars_max: int = 2000 doc_content_chars_max: int = 2000
email: str = "your_email@example.com" email: str = "your_email@example.com"

View File

@ -144,7 +144,7 @@ def _get_default_params() -> dict:
class SearxResults(dict): class SearxResults(dict):
"""Dict like wrapper around search api results.""" """Dict like wrapper around search api results."""
_data = "" _data: str = ""
def __init__(self, data: str): def __init__(self, data: str):
"""Take a raw result from Searx and make it into a dict like object.""" """Take a raw result from Searx and make it into a dict like object."""