mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
BashChain (#260)
Love the project, a ton of fun! I think the PR is pretty self-explanatory, happy to make any changes! I am working on using it in an `LLMBashChain` and may update as that progresses. Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
28be37f470
commit
b7bef36ee1
86
docs/examples/chains/llm_bash.ipynb
Normal file
86
docs/examples/chains/llm_bash.ipynb
Normal file
@ -0,0 +1,86 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# LLM Chain\n",
|
||||
"This notebook showcases using LLMs and a bash process to do perform simple filesystem commands."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
|
||||
"Please write a bash script that prints 'Hello World' to the console.\u001b[32;1m\u001b[1;3m\n",
|
||||
"\n",
|
||||
"```bash\n",
|
||||
"echo \"Hello World\"\n",
|
||||
"```\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m{'success': True, 'outputs': ['Hello World\\n']}\u001b[0m\n",
|
||||
"\u001b[1m> Finished LLMBashChain chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'commands': ['echo \"Hello World\"'],\n",
|
||||
" 'output': {'success': True, 'outputs': ['Hello World\\n']}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains.llm_bash.base import LLMBashChain\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"\n",
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"\n",
|
||||
"text = \"Please write a bash script that prints 'Hello World' to the console.\"\n",
|
||||
"\n",
|
||||
"bash_chain = LLMBashChain(llm=llm, verbose=True)\n",
|
||||
"\n",
|
||||
"bash_chain.run(text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -4,6 +4,7 @@ from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.chains import (
|
||||
ConversationChain,
|
||||
LLMChain,
|
||||
LLMBashChain,
|
||||
LLMMathChain,
|
||||
PALChain,
|
||||
QAWithSourcesChain,
|
||||
@ -25,6 +26,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
"LLMBashChain",
|
||||
"LLMMathChain",
|
||||
"SelfAskWithSearchChain",
|
||||
"SerpAPIWrapper",
|
||||
|
@ -2,6 +2,7 @@
|
||||
from langchain.chains.api.base import APIChain
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.chains.llm_requests import LLMRequestsChain
|
||||
from langchain.chains.pal.base import PALChain
|
||||
@ -12,16 +13,18 @@ from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
|
||||
__all__ = [
|
||||
"APIChain",
|
||||
"ConversationChain",
|
||||
"LLMChain",
|
||||
"LLMBashChain",
|
||||
"LLMMathChain",
|
||||
"PALChain",
|
||||
"QAWithSourcesChain",
|
||||
"SQLDatabaseChain",
|
||||
"VectorDBQA",
|
||||
"SequentialChain",
|
||||
"SimpleSequentialChain",
|
||||
"ConversationChain",
|
||||
"QAWithSourcesChain",
|
||||
"VectorDBQA",
|
||||
"VectorDBQAWithSourcesChain",
|
||||
"PALChain",
|
||||
"APIChain",
|
||||
"LLMRequestsChain",
|
||||
]
|
||||
|
77
langchain/chains/llm_bash/base.py
Normal file
77
langchain/chains/llm_bash/base.py
Normal file
@ -0,0 +1,77 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.utilities.bash import BashProcess
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
class LLMBashChain(Chain, BaseModel):
|
||||
"""Chain that interprets a prompt and executes bash code to perform bash operations.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import LLMBashChain, OpenAI
|
||||
llm_bash = LLMBashChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, Dict[str, list[str]]]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
bash_executor = BashProcess()
|
||||
if self.verbose:
|
||||
print_text(inputs[self.input_key])
|
||||
|
||||
t = llm_executor.predict(question=inputs[self.input_key])
|
||||
if self.verbose:
|
||||
print_text(t, color="green")
|
||||
|
||||
t = t.strip()
|
||||
if t.startswith("```bash"):
|
||||
# Split the string into a list of substrings
|
||||
command_list = t.split('\n')
|
||||
print(command_list)
|
||||
|
||||
# Remove the first and last substrings
|
||||
command_list = [s for s in command_list[1:-1]]
|
||||
output = bash_executor.run(command_list)
|
||||
|
||||
if self.verbose:
|
||||
print_text("\nAnswer: ")
|
||||
print_text(output, color="yellow")
|
||||
|
||||
else:
|
||||
raise ValueError(f"unknown format from LLM: {t}")
|
||||
answer = {"commands": command_list, "output": output}
|
||||
return {self.output_key: answer}
|
22
langchain/chains/llm_bash/prompt.py
Normal file
22
langchain/chains/llm_bash/prompt.py
Normal file
@ -0,0 +1,22 @@
|
||||
# flake8: noqa
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
|
||||
|
||||
Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"
|
||||
|
||||
I need to take the following actions:
|
||||
- List all files in the directory
|
||||
- Create a new directory
|
||||
- Copy the files from the first directory into the second directory
|
||||
```bash
|
||||
ls
|
||||
mkdir myNewDirectory
|
||||
cp -r target/* myNewDirectory
|
||||
```
|
||||
|
||||
That is the format. Begin!
|
||||
|
||||
Question: {question}"""
|
||||
|
||||
PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE)
|
6
langchain/utilities/__init__.py
Normal file
6
langchain/utilities/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BashProcess',
|
||||
]
|
23
langchain/utilities/bash.py
Normal file
23
langchain/utilities/bash.py
Normal file
@ -0,0 +1,23 @@
|
||||
import subprocess
|
||||
from typing import Dict, List, Union
|
||||
|
||||
class BashProcess:
|
||||
"""Executes bash commands and returns the output."""
|
||||
|
||||
def __init__(self, strip_newlines: bool = False):
|
||||
self.strip_newlines = strip_newlines
|
||||
|
||||
|
||||
def run(self, commands: List[str]) -> Dict[str, Union[bool, list[str]]]:
|
||||
outputs = []
|
||||
for command in commands:
|
||||
try:
|
||||
output = subprocess.check_output(command, shell=True).decode()
|
||||
if self.strip_newlines:
|
||||
output = output.strip()
|
||||
outputs.append(output)
|
||||
except subprocess.CalledProcessError as error:
|
||||
outputs.append(str(error))
|
||||
return {"success": False, "outputs": outputs}
|
||||
|
||||
return {"success": True, "outputs": outputs}
|
44
tests/unit_tests/chains/test_bash.py
Normal file
44
tests/unit_tests/chains/test_bash.py
Normal file
@ -0,0 +1,44 @@
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
def test_pwd_command() -> None:
|
||||
"""Test correct functionality."""
|
||||
session = BashProcess()
|
||||
commands = ["pwd"]
|
||||
output = session.run(commands)
|
||||
print(output)
|
||||
|
||||
assert output["outputs"] == [subprocess.check_output("pwd", shell=True).decode()]
|
||||
|
||||
def test_incorrect_command() -> None:
|
||||
"""Test handling of incorrect command."""
|
||||
session = BashProcess()
|
||||
output = session.run(["invalid_command"])
|
||||
assert output["success"] is False
|
||||
|
||||
def test_create_directory_and_files(tmp_path: Path) -> None:
|
||||
"""Test creation of a directory and files in a temporary directory."""
|
||||
session = BashProcess(strip_newlines=True)
|
||||
|
||||
# create a subdirectory in the temporary directory
|
||||
temp_dir = tmp_path / "test_dir"
|
||||
temp_dir.mkdir()
|
||||
|
||||
# run the commands in the temporary directory
|
||||
commands = [
|
||||
f"touch {temp_dir}/file1.txt",
|
||||
f"touch {temp_dir}/file2.txt",
|
||||
f"echo 'hello world' > {temp_dir}/file2.txt",
|
||||
f"cat {temp_dir}/file2.txt"
|
||||
]
|
||||
|
||||
output = session.run(commands)
|
||||
assert output["success"] is True
|
||||
assert output["outputs"][-1] == "hello world"
|
||||
|
||||
# check that the files were created in the temporary directory
|
||||
output = session.run([f"ls {temp_dir}"])
|
||||
assert output["success"] is True
|
||||
assert output["outputs"] == ["file1.txt\nfile2.txt"]
|
24
tests/unit_tests/chains/test_llm_bash.py
Normal file
24
tests/unit_tests/chains/test_llm_bash.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""Test LLM Bash functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_bash_chain() -> LLMBashChain:
|
||||
"""Fake LLM Bash chain for testing."""
|
||||
queries = {
|
||||
_PROMPT_TEMPLATE.format(question="Please write a bash script that prints 'Hello World' to the console."): "```bash\nexpr 1 + 1\n```",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_bash_chain: LLMBashChain) -> None:
|
||||
"""Test simple question that should not need python."""
|
||||
question = "Please write a bash script that prints 'Hello World' to the console."
|
||||
output = fake_llm_bash_chain.run(question)
|
||||
assert output == {'commands': ['expr 1 + 1'], 'output': {'outputs': ['2\n'], 'success': True}}
|
Loading…
Reference in New Issue
Block a user