E2B tool - Improve description wuth uploaded files info (#12355)

This commit is contained in:
Jakub Novák 2023-10-26 11:44:24 -07:00 committed by GitHub
parent dad16af711
commit 9544d64ad8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain.pydantic_v1 import BaseModel, Field from langchain.pydantic_v1 import BaseModel, Field, PrivateAttr
from langchain.tools import BaseTool, Tool from langchain.tools import BaseTool, Tool
from langchain.tools.e2b_data_analysis.unparse import Unparser from langchain.tools.e2b_data_analysis.unparse import Unparser
@ -97,7 +97,7 @@ class E2BDataAnalysisTool(BaseTool):
name = "e2b_data_analysis" name = "e2b_data_analysis"
args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
session: Any session: Any
uploaded_files: List[UploadedFile] = Field(default_factory=list) _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)
def __init__( def __init__(
self, self,
@ -119,7 +119,8 @@ class E2BDataAnalysisTool(BaseTool):
# If no API key is provided, E2B will try to read it from the environment # If no API key is provided, E2B will try to read it from the environment
# variable E2B_API_KEY # variable E2B_API_KEY
session = DataAnalysis( super().__init__(description=base_description, **kwargs)
self.session = DataAnalysis(
api_key=api_key, api_key=api_key,
cwd=cwd, cwd=cwd,
env_vars=env_vars, env_vars=env_vars,
@ -128,21 +129,19 @@ class E2BDataAnalysisTool(BaseTool):
on_exit=on_exit, on_exit=on_exit,
on_artifact=on_artifact, on_artifact=on_artifact,
) )
super().__init__(session=session, description=base_description, **kwargs)
self.uploaded_files = []
def close(self) -> None: def close(self) -> None:
"""Close the cloud sandbox.""" """Close the cloud sandbox."""
self.uploaded_files = [] self._uploaded_files = []
self.session.close() self.session.close()
@property @property
def uploaded_files_description(self) -> str: def uploaded_files_description(self) -> str:
if len(self.uploaded_files) == 0: if len(self._uploaded_files) == 0:
return "" return ""
lines = ["The following files available in the sandbox:"] lines = ["The following files available in the sandbox:"]
for f in self.uploaded_files: for f in self._uploaded_files:
if f.description == "": if f.description == "":
lines.append(f"- path: `{f.remote_path}`") lines.append(f"- path: `{f.remote_path}`")
else: else:
@ -206,15 +205,19 @@ class E2BDataAnalysisTool(BaseTool):
remote_path=remote_path, remote_path=remote_path,
description=description, description=description,
) )
self.uploaded_files.append(f) self._uploaded_files.append(f)
self.description = self.description + "\n" + self.uploaded_files_description
return f return f
def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None: def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
"""Remove uploaded file from the sandbox.""" """Remove uploaded file from the sandbox."""
self.session.filesystem.remove(uploaded_file.remote_path) self.session.filesystem.remove(uploaded_file.remote_path)
self.uploaded_files = [ self._uploaded_files = [
f for f in self.uploaded_files if f.remote_path != uploaded_file.remote_path f
for f in self._uploaded_files
if f.remote_path != uploaded_file.remote_path
] ]
self.description = self.description + "\n" + self.uploaded_files_description
def as_tool(self) -> Tool: def as_tool(self) -> Tool:
return Tool.from_function( return Tool.from_function(