mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
parent
5a269d3175
commit
65c3b146c9
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import platform
|
||||
import warnings
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
@ -16,7 +16,7 @@ from langchain.utilities.bash import BashProcess
|
||||
class ShellInput(BaseModel):
|
||||
"""Commands for the Bash Shell tool."""
|
||||
|
||||
commands: List[str] = Field(
|
||||
commands: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="List of shell commands to run. Deserialized using json.loads",
|
||||
)
|
||||
@ -66,7 +66,7 @@ class ShellTool(BaseTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
commands: List[str],
|
||||
commands: Union[str, List[str]],
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Run commands and return final output."""
|
||||
@ -74,7 +74,7 @@ class ShellTool(BaseTool):
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
commands: List[str],
|
||||
commands: Union[str, List[str]],
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Run commands asynchronously and return final output."""
|
||||
|
@ -30,6 +30,12 @@ def test_shell_tool_init() -> None:
|
||||
assert shell_tool.process is not None
|
||||
|
||||
|
||||
def test_shell_tool_run() -> None:
|
||||
shell_tool = ShellTool()
|
||||
result = shell_tool._run(commands=test_commands)
|
||||
assert result.strip() == "Hello, World!\nAnother command"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shell_tool_arun() -> None:
|
||||
shell_tool = ShellTool()
|
||||
@ -37,7 +43,7 @@ async def test_shell_tool_arun() -> None:
|
||||
assert result.strip() == "Hello, World!\nAnother command"
|
||||
|
||||
|
||||
def test_shell_tool_run() -> None:
|
||||
def test_shell_tool_run_str() -> None:
|
||||
shell_tool = ShellTool()
|
||||
result = shell_tool._run(commands=test_commands)
|
||||
assert result.strip() == "Hello, World!\nAnother command"
|
||||
result = shell_tool._run(commands="echo 'Hello, World!'")
|
||||
assert result.strip() == "Hello, World!"
|
||||
|
Loading…
Reference in New Issue
Block a user