Pass parsed inputs through to tool _run (#4309)

This commit is contained in:
Zander Chase
2023-05-08 09:13:05 -07:00
committed by GitHub
parent 35c9e6ab40
commit 8b284f9ad0
18 changed files with 1067 additions and 11 deletions

View File

@@ -10,6 +10,13 @@ from langchain.tools.file_management.list_dir import ListDirectoryTool
from langchain.tools.file_management.move import MoveFileTool
from langchain.tools.file_management.read import ReadFileTool
from langchain.tools.file_management.write import WriteFileTool
from langchain.tools.gmail import (
GmailCreateDraft,
GmailGetMessage,
GmailGetThread,
GmailSearch,
GmailSendMessage,
)
from langchain.tools.google_places.tool import GooglePlacesTool
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
from langchain.tools.google_serper.tool import GoogleSerperResults, GoogleSerperRun
@@ -56,6 +63,11 @@ __all__ = [
"ExtractTextTool",
"FileSearchTool",
"GetElementsTool",
"GmailCreateDraft",
"GmailGetMessage",
"GmailGetThread",
"GmailSearch",
"GmailSendMessage",
"GooglePlacesTool",
"GoogleSearchResults",
"GoogleSearchRun",

View File

@@ -160,16 +160,19 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
def _parse_input(
self,
tool_input: Union[str, Dict],
) -> None:
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
return tool_input
else:
if input_args is not None:
input_args.validate(tool_input)
result = input_args.parse_obj(tool_input)
return {k: v for k, v in result.dict().items() if k in tool_input}
return tool_input
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
@@ -224,7 +227,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs: Any,
) -> Any:
"""Run the tool."""
self._parse_input(tool_input)
parsed_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@@ -241,7 +244,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs,
)
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
self._run(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported
@@ -263,7 +266,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
self._parse_input(tool_input)
parsed_input = self._parse_input(tool_input)
if not self.verbose and verbose is not None:
verbose_ = verbose
else:
@@ -280,7 +283,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
)
try:
# We then call the tool on the tool input to get an observation
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
observation = (
await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs)
if new_arg_supported

View File

@@ -0,0 +1,17 @@
"""Gmail tools."""
from langchain.tools.gmail.create_draft import GmailCreateDraft
from langchain.tools.gmail.get_message import GmailGetMessage
from langchain.tools.gmail.get_thread import GmailGetThread
from langchain.tools.gmail.search import GmailSearch
from langchain.tools.gmail.send_message import GmailSendMessage
from langchain.tools.gmail.utils import get_gmail_credentials
__all__ = [
"GmailCreateDraft",
"GmailSendMessage",
"GmailSearch",
"GmailGetMessage",
"GmailGetThread",
"get_gmail_credentials",
]

View File

@@ -0,0 +1,27 @@
"""Base class for Gmail tools."""
from __future__ import annotations
from typing import TYPE_CHECKING
from pydantic import Field
from langchain.tools.base import BaseTool
from langchain.tools.gmail.utils import build_resource_service
if TYPE_CHECKING:
# This is for linting and IDE typehints
from googleapiclient.discovery import Resource
else:
try:
# We do this so pydantic can resolve the types when instantiating
from googleapiclient.discovery import Resource
except ImportError:
pass
class GmailBaseTool(BaseTool):
api_resource: Resource = Field(default_factory=build_resource_service)
@classmethod
def from_api_resource(cls, api_resource: Resource) -> "GmailBaseTool":
return cls(service=api_resource)

View File

@@ -0,0 +1,97 @@
import base64
from email.message import EmailMessage
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.gmail.base import GmailBaseTool
class CreateDraftSchema(BaseModel):
message: str = Field(
...,
description="The message to include in the draft.",
)
to: List[str] = Field(
...,
description="The list of recipients.",
)
subject: str = Field(
...,
description="The subject of the message.",
)
cc: Optional[List[str]] = Field(
None,
description="The list of CC recipients.",
)
bcc: Optional[List[str]] = Field(
None,
description="The list of BCC recipients.",
)
class GmailCreateDraft(GmailBaseTool):
name: str = "create_gmail_draft"
description: str = (
"Use this tool to create a draft email with the provided message fields."
)
args_schema: Type[CreateDraftSchema] = CreateDraftSchema
def _prepare_draft_message(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
) -> dict:
draft_message = EmailMessage()
draft_message.set_content(message)
draft_message["To"] = ", ".join(to)
draft_message["Subject"] = subject
if cc is not None:
draft_message["Cc"] = ", ".join(cc)
if bcc is not None:
draft_message["Bcc"] = ", ".join(bcc)
encoded_message = base64.urlsafe_b64encode(draft_message.as_bytes()).decode()
return {"message": {"raw": encoded_message}}
def _run(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
try:
create_message = self._prepare_draft_message(message, to, subject, cc, bcc)
draft = (
self.api_resource.users()
.drafts()
.create(userId="me", body=create_message)
.execute()
)
output = f'Draft created. Draft Id: {draft["id"]}'
return output
except Exception as e:
raise Exception(f"An error occurred: {e}")
async def _arun(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError(f"The tool {self.name} does not support async yet.")

View File

@@ -0,0 +1,68 @@
import base64
import email
from typing import Dict, Optional, Type
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.gmail.base import GmailBaseTool
from langchain.tools.gmail.utils import clean_email_body
class SearchArgsSchema(BaseModel):
message_id: str = Field(
...,
description="The unique ID of the email message, retrieved from a search.",
)
class GmailGetMessage(GmailBaseTool):
name: str = "get_gmail_message"
description: str = (
"Use this tool to fetch an email by message ID."
" Returns the thread ID, snipet, body, subject, and sender."
)
args_schema: Type[SearchArgsSchema] = SearchArgsSchema
def _run(
self,
message_id: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Dict:
"""Run the tool."""
query = (
self.api_resource.users()
.messages()
.get(userId="me", format="raw", id=message_id)
)
message_data = query.execute()
raw_message = base64.urlsafe_b64decode(message_data["raw"])
email_msg = email.message_from_bytes(raw_message)
subject = email_msg["Subject"]
sender = email_msg["From"]
message_body = email_msg.get_payload()
body = clean_email_body(message_body)
return {
"id": message_id,
"threadId": message_data["threadId"],
"snippet": message_data["snippet"],
"body": body,
"subject": subject,
"sender": sender,
}
async def _arun(
self,
message_id: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Dict:
"""Run the tool."""
raise NotImplementedError

View File

@@ -0,0 +1,55 @@
from typing import Dict, Optional, Type
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.gmail.base import GmailBaseTool
class GetThreadSchema(BaseModel):
# From https://support.google.com/mail/answer/7190?hl=en
thread_id: str = Field(
...,
description="The thread ID.",
)
class GmailGetThread(GmailBaseTool):
name: str = "get_gmail_thread"
description: str = (
"Use this tool to search for email messages."
" The input must be a valid Gmail query."
" The output is a JSON list of messages."
)
args_schema: Type[GetThreadSchema] = GetThreadSchema
def _run(
self,
thread_id: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Dict:
"""Run the tool."""
query = self.api_resource.users().threads().get(userId="me", id=thread_id)
thread_data = query.execute()
if not isinstance(thread_data, dict):
raise ValueError("The output of the query must be a list.")
messages = thread_data["messages"]
thread_data["messages"] = []
keys_to_keep = ["id", "snippet", "snippet"]
# TODO: Parse body.
for message in messages:
thread_data["messages"].append(
{k: message[k] for k in keys_to_keep if k in message}
)
return thread_data
async def _arun(
self,
thread_id: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Dict:
"""Run the tool."""
raise NotImplementedError

View File

@@ -0,0 +1,138 @@
import base64
import email
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.gmail.base import GmailBaseTool
from langchain.tools.gmail.utils import clean_email_body
class Resource(str, Enum):
THREADS = "threads"
MESSAGES = "messages"
class SearchArgsSchema(BaseModel):
# From https://support.google.com/mail/answer/7190?hl=en
query: str = Field(
...,
description="The Gmail query. Example filters include from:sender,"
" to:recipient, subject:subject, -filtered_term,"
" in:folder, is:important|read|starred, after:year/mo/date, "
"before:year/mo/date, label:label_name"
' "exact phrase".'
" Search newer/older than using d (day), m (month), and y (year): "
"newer_than:2d, older_than:1y."
" Attachments with extension example: filename:pdf. Multiple term"
" matching example: from:amy OR from:david.",
)
resource: Resource = Field(
default=Resource.MESSAGES,
description="Whether to search for threads or messages.",
)
max_results: int = Field(
default=10,
description="The maximum number of results to return.",
)
class GmailSearch(GmailBaseTool):
name: str = "search_gmail"
description: str = (
"Use this tool to search for email messages or threads."
" The input must be a valid Gmail query."
" The output is a JSON list of the requested resource."
)
args_schema: Type[SearchArgsSchema] = SearchArgsSchema
def _parse_threads(self, threads: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Add the thread message snippets to the thread results
results = []
for thread in threads:
thread_id = thread["id"]
thread_data = (
self.api_resource.users()
.threads()
.get(userId="me", id=thread_id)
.execute()
)
messages = thread_data["messages"]
thread["messages"] = []
for message in messages:
snippet = message["snippet"]
thread["messages"].append({"snippet": snippet, "id": message["id"]})
results.append(thread)
return results
def _parse_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
results = []
for message in messages:
message_id = message["id"]
message_data = (
self.api_resource.users()
.messages()
.get(userId="me", format="raw", id=message_id)
.execute()
)
raw_message = base64.urlsafe_b64decode(message_data["raw"])
email_msg = email.message_from_bytes(raw_message)
subject = email_msg["Subject"]
sender = email_msg["From"]
message_body = email_msg.get_payload()
body = clean_email_body(message_body)
results.append(
{
"id": message["id"],
"threadId": message_data["threadId"],
"snippet": message_data["snippet"],
"body": body,
"subject": subject,
"sender": sender,
}
)
return results
def _run(
self,
query: str,
resource: Resource = Resource.MESSAGES,
max_results: int = 10,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> List[Dict[str, Any]]:
"""Run the tool."""
results = (
self.api_resource.users()
.messages()
.list(userId="me", q=query, maxResults=max_results)
.execute()
.get(resource.value, [])
)
if resource == Resource.THREADS:
return self._parse_threads(results)
elif resource == Resource.MESSAGES:
return self._parse_messages(results)
else:
raise NotImplementedError(f"Resource of type {resource} not implemented.")
async def _arun(
self,
query: str,
resource: Resource = Resource.MESSAGES,
max_results: int = 10,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> List[Dict[str, Any]]:
"""Run the tool."""
raise NotImplementedError

View File

@@ -0,0 +1,100 @@
"""Send Gmail messages."""
import base64
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.gmail.base import GmailBaseTool
class SendMessageSchema(BaseModel):
message: str = Field(
...,
description="The message to send.",
)
to: List[str] = Field(
...,
description="The list of recipients.",
)
subject: str = Field(
...,
description="The subject of the message.",
)
cc: Optional[List[str]] = Field(
None,
description="The list of CC recipients.",
)
bcc: Optional[List[str]] = Field(
None,
description="The list of BCC recipients.",
)
class GmailSendMessage(GmailBaseTool):
name: str = "send_gmail_message"
description: str = (
"Use this tool to send email messages." " The input is the message, recipents"
)
def _prepare_message(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Create a message for an email."""
mime_message = MIMEMultipart()
mime_message.attach(MIMEText(message, "html"))
mime_message["To"] = ", ".join(to)
mime_message["Subject"] = subject
if cc is not None:
mime_message["Cc"] = ", ".join(cc)
if bcc is not None:
mime_message["Bcc"] = ", ".join(bcc)
encoded_message = base64.urlsafe_b64encode(mime_message.as_bytes()).decode()
return {"raw": encoded_message}
def _run(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Run the tool."""
try:
create_message = self._prepare_message(message, to, subject, cc=cc, bcc=bcc)
send_message = (
self.api_resource.users()
.messages()
.send(userId="me", body=create_message)
)
sent_message = send_message.execute()
return f'Message sent. Message Id: {sent_message["id"]}'
except Exception as error:
raise Exception(f"An error occurred: {error}")
async def _arun(
self,
message: str,
to: List[str],
subject: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Run the tool asynchronously."""
raise NotImplementedError(f"The tool {self.name} does not support async yet.")

View File

@@ -0,0 +1,117 @@
"""Gmail tool utils."""
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, List, Optional, Tuple
if TYPE_CHECKING:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import Resource
from googleapiclient.discovery import build as build_resource
logger = logging.getLogger(__name__)
def import_google() -> Tuple[Request, Credentials]:
# google-auth-httplib2
try:
from google.auth.transport.requests import Request # noqa: F401
from google.oauth2.credentials import Credentials # noqa: F401
except ImportError:
raise ValueError(
"You need to install google-auth-httplib2 to use this toolkit. "
"Try running pip install --upgrade google-auth-httplib2"
)
return Request, Credentials
def import_installed_app_flow() -> InstalledAppFlow:
try:
from google_auth_oauthlib.flow import InstalledAppFlow
except ImportError:
raise ValueError(
"You need to install google-auth-oauthlib to use this toolkit. "
"Try running pip install --upgrade google-auth-oauthlib"
)
return InstalledAppFlow
def import_googleapiclient_resource_builder() -> build_resource:
try:
from googleapiclient.discovery import build
except ImportError:
raise ValueError(
"You need to install googleapiclient to use this toolkit. "
"Try running pip install --upgrade google-api-python-client"
)
return build
DEFAULT_SCOPES = ["https://mail.google.com/"]
DEFAULT_CREDS_TOKEN_FILE = "token.json"
DEFAULT_CLIENT_SECRETS_FILE = "credentials.json"
def get_gmail_credentials(
token_file: Optional[str] = None,
client_secrets_file: Optional[str] = None,
scopes: Optional[List[str]] = None,
) -> Credentials:
"""Get credentials."""
# From https://developers.google.com/gmail/api/quickstart/python
Request, Credentials = import_google()
InstalledAppFlow = import_installed_app_flow()
creds = None
scopes = scopes or DEFAULT_SCOPES
token_file = token_file or DEFAULT_CREDS_TOKEN_FILE
client_secrets_file = client_secrets_file or DEFAULT_CLIENT_SECRETS_FILE
# The file token.json stores the user's access and refresh tokens, and is
# created automatically when the authorization flow completes for the first
# time.
if os.path.exists(token_file):
creds = Credentials.from_authorized_user_file(token_file, scopes)
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
# https://developers.google.com/gmail/api/quickstart/python#authorize_credentials_for_a_desktop_application # noqa
flow = InstalledAppFlow.from_client_secrets_file(
client_secrets_file, scopes
)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open(token_file, "w") as token:
token.write(creds.to_json())
return creds
def build_resource_service(
credentials: Optional[Credentials] = None,
service_name: str = "gmail",
service_version: str = "v1",
) -> Resource:
"""Build a Gmail service."""
credentials = credentials or get_gmail_credentials()
builder = import_googleapiclient_resource_builder()
return builder(service_name, service_version, credentials=credentials)
def clean_email_body(body: str) -> str:
"""Clean email body."""
try:
from bs4 import BeautifulSoup
try:
soup = BeautifulSoup(str(body), "html.parser")
body = soup.get_text()
return str(body)
except Exception as e:
logger.error(e)
return str(body)
except ImportError:
logger.warning("BeautifulSoup not installed. Skipping cleaning.")
return str(body)