Custom prompt option for llm_bash and api chains (#612)

Co-authored-by: lesscomfortable <pancho_ingham@hotmail.com>
This commit is contained in:
Francisco Ingham 2023-01-14 12:22:52 -03:00 committed by GitHub
parent 67808bad0e
commit 1787c473b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 8 deletions

View File

@ -28,7 +28,7 @@
"\n", "\n",
"Answer: \u001b[33;1m\u001b[1;3mHello World\n", "Answer: \u001b[33;1m\u001b[1;3mHello World\n",
"\u001b[0m\n", "\u001b[0m\n",
"\u001b[1m> Finished LLMBashChain chain.\u001b[0m\n" "\u001b[1m> Finished chain.\u001b[0m\n"
] ]
}, },
{ {
@ -55,12 +55,83 @@
"bash_chain.run(text)" "bash_chain.run(text)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Customize Prompt\n",
"You can also customize the prompt that is used. Here is an example prompting to avoid using the 'echo' utility"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n",
"Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n",
"I need to take the following actions:\n",
"- List all files in the directory\n",
"- Create a new directory\n",
"- Copy the files from the first directory into the second directory\n",
"```bash\n",
"ls\n",
"mkdir myNewDirectory\n",
"cp -r target/* myNewDirectory\n",
"```\n",
"\n",
"Do not use 'echo' when writing the script.\n",
"\n",
"That is the format. Begin!\n",
"Question: {question}\"\"\"\n",
"\n",
"PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMBashChain chain...\u001b[0m\n",
"Please write a bash script that prints 'Hello World' to the console.\u001b[32;1m\u001b[1;3m\n",
"\n",
"```bash\n",
"printf \"Hello World\\n\"\n",
"```\u001b[0m['```bash', 'printf \"Hello World\\\\n\"', '```']\n",
"\n",
"Answer: \u001b[33;1m\u001b[1;3mHello World\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Hello World\\n'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bash_chain = LLMBashChain(llm=llm, prompt=PROMPT, verbose=True)\n",
"\n",
"text = \"Please write a bash script that prints 'Hello World' to the console.\"\n",
"\n",
"bash_chain.run(text)"
]
} }
], ],
"metadata": { "metadata": {
@ -79,7 +150,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.9" "version": "3.10.6"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -9,6 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
@ -80,12 +81,18 @@ class APIChain(Chain, BaseModel):
@classmethod @classmethod
def from_llm_and_api_docs( def from_llm_and_api_docs(
cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any cls,
llm: BaseLLM,
api_docs: str,
headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
**kwargs: Any,
) -> APIChain: ) -> APIChain:
"""Load chain from just an LLM and the api docs.""" """Load chain from just an LLM and the api docs."""
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT) get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
requests_wrapper = RequestsWrapper(headers=headers) requests_wrapper = RequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT) get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
return cls( return cls(
api_request_chain=get_request_chain, api_request_chain=get_request_chain,
api_answer_chain=get_answer_chain, api_answer_chain=get_answer_chain,

View File

@ -7,6 +7,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT from langchain.chains.llm_bash.prompt import PROMPT
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
@ -24,6 +25,7 @@ class LLMBashChain(Chain, BaseModel):
"""LLM wrapper to use.""" """LLM wrapper to use."""
input_key: str = "question" #: :meta private: input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private: output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -48,7 +50,7 @@ class LLMBashChain(Chain, BaseModel):
return [self.output_key] return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
bash_executor = BashProcess() bash_executor = BashProcess()
if self.verbose: if self.verbose:
self.callback_manager.on_text(inputs[self.input_key]) self.callback_manager.on_text(inputs[self.input_key])