mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
add LLMBashChain to experimental (#11305)
Add LLMBashChain to experimental
This commit is contained in:
parent
29b9a890d4
commit
5e2d5047af
@ -0,0 +1 @@
|
|||||||
|
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
|
125
libs/experimental/langchain_experimental/llm_bash/base.py
Normal file
125
libs/experimental/langchain_experimental/llm_bash/base.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""Chain that interprets a prompt and executes bash operations."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.schema import BasePromptTemplate, OutputParserException
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
|
||||||
|
from langchain_experimental.llm_bash.bash import BashProcess
|
||||||
|
from langchain_experimental.llm_bash.prompt import PROMPT
|
||||||
|
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMBashChain(Chain):
|
||||||
|
"""Chain that interprets a prompt and executes bash operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chains import LLMBashChain
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
llm_bash = LLMBashChain.from_llm(OpenAI())
|
||||||
|
"""
|
||||||
|
|
||||||
|
llm_chain: LLMChain
|
||||||
|
llm: Optional[BaseLanguageModel] = None
|
||||||
|
"""[Deprecated] LLM wrapper to use."""
|
||||||
|
input_key: str = "question" #: :meta private:
|
||||||
|
output_key: str = "answer" #: :meta private:
|
||||||
|
prompt: BasePromptTemplate = PROMPT
|
||||||
|
"""[Deprecated]"""
|
||||||
|
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
|
if "llm" in values:
|
||||||
|
warnings.warn(
|
||||||
|
"Directly instantiating an LLMBashChain with an llm is deprecated. "
|
||||||
|
"Please instantiate with llm_chain or using the from_llm class method."
|
||||||
|
)
|
||||||
|
if "llm_chain" not in values and values["llm"] is not None:
|
||||||
|
prompt = values.get("prompt", PROMPT)
|
||||||
|
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_prompt(cls, values: Dict) -> Dict:
|
||||||
|
if values["llm_chain"].prompt.output_parser is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The prompt used by llm_chain is expected to have an output_parser."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Expect input key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.input_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Expect output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
|
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
|
||||||
|
|
||||||
|
t = self.llm_chain.predict(
|
||||||
|
question=inputs[self.input_key], callbacks=_run_manager.get_child()
|
||||||
|
)
|
||||||
|
_run_manager.on_text(t, color="green", verbose=self.verbose)
|
||||||
|
t = t.strip()
|
||||||
|
try:
|
||||||
|
parser = self.llm_chain.prompt.output_parser
|
||||||
|
command_list = parser.parse(t) # type: ignore[union-attr]
|
||||||
|
except OutputParserException as e:
|
||||||
|
_run_manager.on_chain_error(e, verbose=self.verbose)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
_run_manager.on_text("\nCode: ", verbose=self.verbose)
|
||||||
|
_run_manager.on_text(
|
||||||
|
str(command_list), color="yellow", verbose=self.verbose
|
||||||
|
)
|
||||||
|
output = self.bash_process.run(command_list)
|
||||||
|
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
|
||||||
|
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
|
||||||
|
return {self.output_key: output}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _chain_type(self) -> str:
|
||||||
|
return "llm_bash_chain"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm(
|
||||||
|
cls,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
prompt: BasePromptTemplate = PROMPT,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> LLMBashChain:
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
return cls(llm_chain=llm_chain, **kwargs)
|
184
libs/experimental/langchain_experimental/llm_bash/bash.py
Normal file
184
libs/experimental/langchain_experimental/llm_bash/bash.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
"""Wrapper around subprocess to run commands."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
from typing import TYPE_CHECKING, List, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pexpect
|
||||||
|
|
||||||
|
|
||||||
|
class BashProcess:
|
||||||
|
"""
|
||||||
|
Wrapper class for starting subprocesses.
|
||||||
|
Uses the python built-in subprocesses.run()
|
||||||
|
Persistent processes are **not** available
|
||||||
|
on Windows systems, as pexpect makes use of
|
||||||
|
Unix pseudoterminals (ptys). MacOS and Linux
|
||||||
|
are okay.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
bash = BashProcess(
|
||||||
|
strip_newlines = False,
|
||||||
|
return_err_output = False,
|
||||||
|
persistent = False
|
||||||
|
)
|
||||||
|
bash.run('echo \'hello world\'')
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
strip_newlines: bool = False
|
||||||
|
"""Whether or not to run .strip() on the output"""
|
||||||
|
return_err_output: bool = False
|
||||||
|
"""Whether or not to return the output of a failed
|
||||||
|
command, or just the error message and stacktrace"""
|
||||||
|
persistent: bool = False
|
||||||
|
"""Whether or not to spawn a persistent session
|
||||||
|
NOTE: Unavailable for Windows environments"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strip_newlines: bool = False,
|
||||||
|
return_err_output: bool = False,
|
||||||
|
persistent: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes with default settings
|
||||||
|
"""
|
||||||
|
self.strip_newlines = strip_newlines
|
||||||
|
self.return_err_output = return_err_output
|
||||||
|
self.prompt = ""
|
||||||
|
self.process = None
|
||||||
|
if persistent:
|
||||||
|
self.prompt = str(uuid4())
|
||||||
|
self.process = self._initialize_persistent_process(self, self.prompt)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _lazy_import_pexpect() -> pexpect:
|
||||||
|
"""Import pexpect only when needed."""
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
raise ValueError(
|
||||||
|
"Persistent bash processes are not yet supported on Windows."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import pexpect
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pexpect required for persistent bash processes."
|
||||||
|
" To install, run `pip install pexpect`."
|
||||||
|
)
|
||||||
|
return pexpect
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _initialize_persistent_process(self: BashProcess, prompt: str) -> pexpect.spawn:
|
||||||
|
# Start bash in a clean environment
|
||||||
|
# Doesn't work on windows
|
||||||
|
"""
|
||||||
|
Initializes a persistent bash setting in a
|
||||||
|
clean environment.
|
||||||
|
NOTE: Unavailable on Windows
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Prompt(str): the bash command to execute
|
||||||
|
""" # noqa: E501
|
||||||
|
pexpect = self._lazy_import_pexpect()
|
||||||
|
process = pexpect.spawn(
|
||||||
|
"env", ["-i", "bash", "--norc", "--noprofile"], encoding="utf-8"
|
||||||
|
)
|
||||||
|
# Set the custom prompt
|
||||||
|
process.sendline("PS1=" + prompt)
|
||||||
|
|
||||||
|
process.expect_exact(prompt, timeout=10)
|
||||||
|
return process
|
||||||
|
|
||||||
|
def run(self, commands: Union[str, List[str]]) -> str:
|
||||||
|
"""
|
||||||
|
Run commands in either an existing persistent
|
||||||
|
subprocess or on in a new subprocess environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
commands(List[str]): a list of commands to
|
||||||
|
execute in the session
|
||||||
|
""" # noqa: E501
|
||||||
|
if isinstance(commands, str):
|
||||||
|
commands = [commands]
|
||||||
|
commands = ";".join(commands)
|
||||||
|
if self.process is not None:
|
||||||
|
return self._run_persistent(
|
||||||
|
commands,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._run(commands)
|
||||||
|
|
||||||
|
def _run(self, command: str) -> str:
|
||||||
|
"""
|
||||||
|
Runs a command in a subprocess and returns
|
||||||
|
the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: The command to run
|
||||||
|
""" # noqa: E501
|
||||||
|
try:
|
||||||
|
output = subprocess.run(
|
||||||
|
command,
|
||||||
|
shell=True,
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
).stdout.decode()
|
||||||
|
except subprocess.CalledProcessError as error:
|
||||||
|
if self.return_err_output:
|
||||||
|
return error.stdout.decode()
|
||||||
|
return str(error)
|
||||||
|
if self.strip_newlines:
|
||||||
|
output = output.strip()
|
||||||
|
return output
|
||||||
|
|
||||||
|
def process_output(self, output: str, command: str) -> str:
|
||||||
|
"""
|
||||||
|
Uses regex to remove the command from the output
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: a process' output string
|
||||||
|
command: the executed command
|
||||||
|
""" # noqa: E501
|
||||||
|
pattern = re.escape(command) + r"\s*\n"
|
||||||
|
output = re.sub(pattern, "", output, count=1)
|
||||||
|
return output.strip()
|
||||||
|
|
||||||
|
def _run_persistent(self, command: str) -> str:
|
||||||
|
"""
|
||||||
|
Runs commands in a persistent environment
|
||||||
|
and returns the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: the command to execute
|
||||||
|
""" # noqa: E501
|
||||||
|
pexpect = self._lazy_import_pexpect()
|
||||||
|
if self.process is None:
|
||||||
|
raise ValueError("Process not initialized")
|
||||||
|
self.process.sendline(command)
|
||||||
|
|
||||||
|
# Clear the output with an empty string
|
||||||
|
self.process.expect(self.prompt, timeout=10)
|
||||||
|
self.process.sendline("")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process.expect([self.prompt, pexpect.EOF], timeout=10)
|
||||||
|
except pexpect.TIMEOUT:
|
||||||
|
return f"Timeout error while executing command {command}"
|
||||||
|
if self.process.after == pexpect.EOF:
|
||||||
|
return f"Exited with error status: {self.process.exitstatus}"
|
||||||
|
output = self.process.before
|
||||||
|
output = self.process_output(output, command)
|
||||||
|
if self.strip_newlines:
|
||||||
|
return output.strip()
|
||||||
|
return output
|
64
libs/experimental/langchain_experimental/llm_bash/prompt.py
Normal file
64
libs/experimental/langchain_experimental/llm_bash/prompt.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
|
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
|
||||||
|
|
||||||
|
Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"
|
||||||
|
|
||||||
|
I need to take the following actions:
|
||||||
|
- List all files in the directory
|
||||||
|
- Create a new directory
|
||||||
|
- Copy the files from the first directory into the second directory
|
||||||
|
```bash
|
||||||
|
ls
|
||||||
|
mkdir myNewDirectory
|
||||||
|
cp -r target/* myNewDirectory
|
||||||
|
```
|
||||||
|
|
||||||
|
That is the format. Begin!
|
||||||
|
|
||||||
|
Question: {question}"""
|
||||||
|
|
||||||
|
|
||||||
|
class BashOutputParser(BaseOutputParser):
|
||||||
|
"""Parser for bash output."""
|
||||||
|
|
||||||
|
def parse(self, text: str) -> List[str]:
|
||||||
|
if "```bash" in text:
|
||||||
|
return self.get_code_blocks(text)
|
||||||
|
else:
|
||||||
|
raise OutputParserException(
|
||||||
|
f"Failed to parse bash output. Got: {text}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_code_blocks(t: str) -> List[str]:
|
||||||
|
"""Get multiple code blocks from the LLM result."""
|
||||||
|
code_blocks: List[str] = []
|
||||||
|
# Bash markdown code blocks
|
||||||
|
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
|
||||||
|
for match in pattern.finditer(t):
|
||||||
|
matched = match.group(1).strip()
|
||||||
|
if matched:
|
||||||
|
code_blocks.extend(
|
||||||
|
[line for line in matched.split("\n") if line.strip()]
|
||||||
|
)
|
||||||
|
|
||||||
|
return code_blocks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "bash"
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT = PromptTemplate(
|
||||||
|
input_variables=["question"],
|
||||||
|
template=_PROMPT_TEMPLATE,
|
||||||
|
output_parser=BashOutputParser(),
|
||||||
|
)
|
102
libs/experimental/tests/unit_tests/test_bash.py
Normal file
102
libs/experimental/tests/unit_tests/test_bash.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""Test the bash utility."""
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_experimental.llm_bash.bash import BashProcess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_pwd_command() -> None:
|
||||||
|
"""Test correct functionality."""
|
||||||
|
session = BashProcess()
|
||||||
|
commands = ["pwd"]
|
||||||
|
output = session.run(commands)
|
||||||
|
|
||||||
|
assert output == subprocess.check_output("pwd", shell=True).decode()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_pwd_command_persistent() -> None:
|
||||||
|
"""Test correct functionality when the bash process is persistent."""
|
||||||
|
session = BashProcess(persistent=True, strip_newlines=True)
|
||||||
|
commands = ["pwd"]
|
||||||
|
output = session.run(commands)
|
||||||
|
|
||||||
|
assert subprocess.check_output("pwd", shell=True).decode().strip() in output
|
||||||
|
|
||||||
|
session.run(["cd .."])
|
||||||
|
new_output = session.run(["pwd"])
|
||||||
|
# Assert that the new_output is a parent of the old output
|
||||||
|
assert Path(output).parent == Path(new_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_incorrect_command() -> None:
|
||||||
|
"""Test handling of incorrect command."""
|
||||||
|
session = BashProcess()
|
||||||
|
output = session.run(["invalid_command"])
|
||||||
|
assert output == "Command 'invalid_command' returned non-zero exit status 127."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_incorrect_command_return_err_output() -> None:
|
||||||
|
"""Test optional returning of shell output on incorrect command."""
|
||||||
|
session = BashProcess(return_err_output=True)
|
||||||
|
output = session.run(["invalid_command"])
|
||||||
|
assert re.match(
|
||||||
|
r"^/bin/sh:.*invalid_command.*(?:not found|Permission denied).*$", output
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_create_directory_and_files(tmp_path: Path) -> None:
|
||||||
|
"""Test creation of a directory and files in a temporary directory."""
|
||||||
|
session = BashProcess(strip_newlines=True)
|
||||||
|
|
||||||
|
# create a subdirectory in the temporary directory
|
||||||
|
temp_dir = tmp_path / "test_dir"
|
||||||
|
temp_dir.mkdir()
|
||||||
|
|
||||||
|
# run the commands in the temporary directory
|
||||||
|
commands = [
|
||||||
|
f"touch {temp_dir}/file1.txt",
|
||||||
|
f"touch {temp_dir}/file2.txt",
|
||||||
|
f"echo 'hello world' > {temp_dir}/file2.txt",
|
||||||
|
f"cat {temp_dir}/file2.txt",
|
||||||
|
]
|
||||||
|
|
||||||
|
output = session.run(commands)
|
||||||
|
assert output == "hello world"
|
||||||
|
|
||||||
|
# check that the files were created in the temporary directory
|
||||||
|
output = session.run([f"ls {temp_dir}"])
|
||||||
|
assert output == "file1.txt\nfile2.txt"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="flaky on GHA, TODO to fix")
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_create_bash_persistent() -> None:
|
||||||
|
"""Test the pexpect persistent bash terminal"""
|
||||||
|
session = BashProcess(persistent=True)
|
||||||
|
response = session.run("echo hello")
|
||||||
|
response += session.run("echo world")
|
||||||
|
|
||||||
|
assert "hello" in response
|
||||||
|
assert "world" in response
|
109
libs/experimental/tests/unit_tests/test_llm_bash.py
Normal file
109
libs/experimental/tests/unit_tests/test_llm_bash.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""Test LLM Bash functionality."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.schema import OutputParserException
|
||||||
|
|
||||||
|
from langchain_experimental.llm_bash.base import LLMBashChain
|
||||||
|
from langchain_experimental.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
|
||||||
|
from tests.unit_tests.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
_SAMPLE_CODE = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
```
|
||||||
|
Unrelated text
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
_SAMPLE_CODE_2_LINES = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
|
||||||
|
echo world
|
||||||
|
```
|
||||||
|
Unrelated text
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def output_parser() -> BashOutputParser:
|
||||||
|
"""Output parser for testing."""
|
||||||
|
return BashOutputParser()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform.startswith("win"), reason="Test not supported on Windows"
|
||||||
|
)
|
||||||
|
def test_simple_question() -> None:
|
||||||
|
"""Test simple question that should not need python."""
|
||||||
|
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||||
|
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||||
|
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||||
|
output = fake_llm_bash_chain.run(question)
|
||||||
|
assert output == "2\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code(output_parser: BashOutputParser) -> None:
|
||||||
|
"""Test the parser."""
|
||||||
|
code_lines = output_parser.parse(_SAMPLE_CODE)
|
||||||
|
code = [c for c in code_lines if c.strip()]
|
||||||
|
assert code == code_lines
|
||||||
|
assert code == ["echo hello"]
|
||||||
|
|
||||||
|
code_lines = output_parser.parse(_SAMPLE_CODE + _SAMPLE_CODE_2_LINES)
|
||||||
|
assert code_lines == ["echo hello", "echo hello", "echo world"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parsing_error() -> None:
|
||||||
|
"""Test that LLM Output without a bash block raises an exce"""
|
||||||
|
question = "Please echo 'hello world' to the terminal."
|
||||||
|
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||||
|
queries = {
|
||||||
|
prompt: """
|
||||||
|
```text
|
||||||
|
echo 'hello world'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||||
|
with pytest.raises(OutputParserException):
|
||||||
|
fake_llm_bash_chain.run(question)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code_lines_mixed_blocks(output_parser: BashOutputParser) -> None:
|
||||||
|
text = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
ls && pwd && ls
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
print("hello")
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
echo goodbye
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
code_lines = output_parser.parse(text)
|
||||||
|
assert code_lines == ["echo hello", "ls && pwd && ls", "echo goodbye"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_code_lines_simple_nested_ticks(output_parser: BashOutputParser) -> None:
|
||||||
|
"""Test that backticks w/o a newline are ignored."""
|
||||||
|
text = """
|
||||||
|
Unrelated text
|
||||||
|
```bash
|
||||||
|
echo hello
|
||||||
|
echo "```bash is in this string```"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
code_lines = output_parser.parse(text)
|
||||||
|
assert code_lines == ["echo hello", 'echo "```bash is in this string```"']
|
Loading…
Reference in New Issue
Block a user