diff --git a/docs/extras/use_cases/self_check/smart_llm.ipynb b/docs/extras/use_cases/self_check/smart_llm.ipynb new file mode 100644 index 00000000000..617a36307ee --- /dev/null +++ b/docs/extras/use_cases/self_check/smart_llm.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "9e9b7651", + "metadata": {}, + "source": [ + "# How to use a SmartLLMChain\n", + "\n", + "A SmartLLMChain is a form of self-critique chain that can help you if you have particularly complex questions to answer. Instead of doing a single LLM pass, it performs these 3 steps:\n", + "1. Ideation: Pass the user prompt n times through the LLM to get n output proposals (called \"ideas\"), where n is a parameter you can set \n", + "2. Critique: The LLM critiques all ideas to find possible flaws and picks the best one \n", + "3. Resolve: The LLM tries to improve upon the best idea (as chosen in the critique step) and outputs it. This is then the final output.\n", + "\n", + "SmartLLMChains are based on the SmartGPT workflow proposed in https://youtu.be/wVzuvf9D9BU.\n", + "\n", + "Note that SmartLLMChains\n", + "- use more LLM passes (i.e. n+2 instead of just 1)\n", + "- only work when the underlying LLM has the capability for reflection, which smaller models often don't\n", + "- only work with underlying models that return exactly 1 output, not multiple\n", + "\n", + "This notebook demonstrates how to use a SmartLLMChain." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "714dede0", + "metadata": {}, + "source": [ + "##### Same LLM for all steps" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d3f7fb22", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"...\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "10e5ece6", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain_experimental.smart_llm import SmartLLMChain" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1780da51", + "metadata": {}, + "source": [ + "As an example question, we will use \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\". This is an example from the original SmartGPT video (https://youtu.be/wVzuvf9D9BU?t=384). While this seems like a very easy question, LLMs struggle with these kinds of questions that involve numbers and physical reasoning.\n", + "\n", + "As we will see, all 3 initial ideas are completely wrong - even though we're using GPT-4! Only when using self-reflection do we get a correct answer. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "054af6b1", + "metadata": {}, + "outputs": [], + "source": [ + "hard_question = \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. 
How do I do it?\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8049cecd", + "metadata": {}, + "source": [ + "First, we create an LLM and a prompt template" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "811ea8e1", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = PromptTemplate.from_template(hard_question)\n", + "llm = ChatOpenAI(temperature=0, model_name=\"gpt-4\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "50b602e4", + "metadata": {}, + "source": [ + "Now we can create a SmartLLMChain" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8cd49199", + "metadata": {}, + "outputs": [], + "source": [ + "chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=3, verbose=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "6a72f276", + "metadata": {}, + "source": [ + "Now we can use the SmartLLMChain as a drop-in replacement for a regular LLMChain. E.g.:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "074e5e75", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new SmartLLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3mI have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\u001b[0m\n", + "Idea 1:\n", + "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug.\n", + "3. Fill the 6-liter jug again.\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n", + "5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n", + "Idea 2:\n", + "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug.\n", + "3. Fill the 6-liter jug again.\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n", + "5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug.\n", + "6. Empty the 12-liter jug.\n", + "7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug.\n", + "8. Fill the 6-liter jug completely again.\n", + "9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it.\n", + "10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug).\u001b[0m\n", + "Idea 3:\n", + "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug.\n", + "3. Fill the 6-liter jug again.\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n", + "5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n", + "Critique:\n", + "\u001b[33;1m\u001b[1;3mIdea 1:\n", + "1. Fill the 6-liter jug completely. (No flaw)\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n", + "3. Fill the 6-liter jug again. (No flaw)\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n", + "5. The amount of water left in the 6-liter jug will be exactly 6 liters. 
(Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\n", + "\n", + "Idea 2:\n", + "1. Fill the 6-liter jug completely. (No flaw)\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n", + "3. Fill the 6-liter jug again. (No flaw)\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n", + "5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug. (Flaw: This statement is incorrect, as the 12-liter jug will not be full and there will be no water left in the 6-liter jug.)\n", + "6. Empty the 12-liter jug. (No flaw)\n", + "7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water left in the 6-liter jug.)\n", + "8. Fill the 6-liter jug completely again. (No flaw)\n", + "9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water in the 12-liter jug.)\n", + "10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug). (Flaw: This conclusion is based on the incorrect assumptions made in the previous steps.)\n", + "\n", + "Idea 3:\n", + "1. Fill the 6-liter jug completely. (No flaw)\n", + "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n", + "3. Fill the 6-liter jug again. (No flaw)\n", + "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n", + "5. The amount of water left in the 6-liter jug will be exactly 6 liters. (Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\u001b[0m\n", + "Resolution:\n", + "\u001b[32;1m\u001b[1;3m1. Fill the 12-liter jug completely.\n", + "2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\n", + "3. The amount of water left in the 12-liter jug will be exactly 6 liters.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'1. Fill the 12-liter jug completely.\\n2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\\n3. The amount of water left in the 12-liter jug will be exactly 6 liters.'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run({})" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bbfebea1", + "metadata": {}, + "source": [ + "##### Different LLM for different steps" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "5be6ec08", + "metadata": {}, + "source": [ + "You can also use different LLMs for the different steps by passing `ideation_llm`, `critique_llm` and `resolver_llm`. You might want to do this to use a more creative (i.e., high-temperature) model for ideation and a stricter (i.e., low-temperature) model for critique and resolution."
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9c33fa19", + "metadata": {}, + "outputs": [], + "source": [ + "chain = SmartLLMChain(\n", + " ideation_llm=ChatOpenAI(temperature=0.9, model_name=\"gpt-4\"),\n", + " llm=ChatOpenAI(\n", + " temperature=0, model_name=\"gpt-4\"\n", + " ), # will be used for critique and resolution as no specific llms are given\n", + " prompt=prompt,\n", + " n_ideas=3,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "886c1cc1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/experimental/langchain_experimental/smart_llm/__init__.py b/libs/experimental/langchain_experimental/smart_llm/__init__.py new file mode 100644 index 00000000000..925a05cec5b --- /dev/null +++ b/libs/experimental/langchain_experimental/smart_llm/__init__.py @@ -0,0 +1,5 @@ +"""Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU)""" + +from langchain_experimental.smart_llm.base import SmartLLMChain + +__all__ = ["SmartLLMChain"] diff --git a/libs/experimental/langchain_experimental/smart_llm/base.py b/libs/experimental/langchain_experimental/smart_llm/base.py new file mode 100644 index 00000000000..8495e30ae87 --- /dev/null +++ b/libs/experimental/langchain_experimental/smart_llm/base.py @@ -0,0 +1,323 @@ +"""Chain for applying self-critique using the SmartGPT workflow.""" +from typing import Any, Dict, List, Optional, Tuple, Type + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.input import get_colored_text +from langchain.prompts.base import BasePromptTemplate +from langchain.prompts.chat import ( + AIMessagePromptTemplate, + BaseMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, +) +from langchain.schema import LLMResult, PromptValue +from pydantic import Extra, root_validator + + +class SmartLLMChain(Chain): + """ + Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU) + + A SmartLLMChain is an LLMChain that, instead of simply passing the prompt to the + LLM, performs these 3 steps: + 1. Ideate: Pass the user prompt to an ideation LLM n_ideas times, + each result is an "idea" + 2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas + & picks the best one + 3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea + & outputs only the improved version of the best idea + + In total, a SmartLLMChain pass will use n_ideas+2 LLM calls. + + Note that SmartLLMChain will only improve results (compared to a basic LLMChain) + when the underlying models have the capability for reflection, which smaller models + often don't. + + Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
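+
+    Example (a minimal usage sketch mirroring the docs notebook; it assumes a
+    ChatOpenAI model and an OPENAI_API_KEY are available):
+        .. code-block:: python
+
+            from langchain.chat_models import ChatOpenAI
+            from langchain.prompts import PromptTemplate
+            from langchain_experimental.smart_llm import SmartLLMChain
+
+            prompt = PromptTemplate.from_template("{question}")
+            llm = ChatOpenAI(temperature=0, model_name="gpt-4")
+            chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=3, verbose=True)
+            answer = chain.run(question="I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?")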
+ """ + + class SmartLLMChainHistory: + question: str = "" + ideas: List[str] = [] + critique: str = "" + + @property + def n_ideas(self) -> int: + return len(self.ideas) + + def ideation_prompt_inputs(self) -> Dict[str, Any]: + return {"question": self.question} + + def critique_prompt_inputs(self) -> Dict[str, Any]: + return { + "question": self.question, + **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)}, + } + + def resolve_prompt_inputs(self) -> Dict[str, Any]: + return { + "question": self.question, + **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)}, + "critique": self.critique, + } + + prompt: BasePromptTemplate + """Prompt object to use.""" + ideation_llm: Optional[BaseLanguageModel] = None + """LLM to use in ideation step. If None given, 'llm' will be used.""" + critique_llm: Optional[BaseLanguageModel] = None + """LLM to use in critique step. If None given, 'llm' will be used.""" + resolver_llm: Optional[BaseLanguageModel] = None + """LLM to use in resolve step. If None given, 'llm' will be used.""" + llm: Optional[BaseLanguageModel] = None + """LLM to use for each step, if no specific llm for that step is given.""" + n_ideas: int = 3 + """Number of ideas to generate in the ideation step.""" + return_intermediate_steps: bool = False + """Whether to return ideas and critique, in addition to resolution.""" + history: SmartLLMChainHistory = SmartLLMChainHistory() + + class Config: + extra = Extra.forbid + + @root_validator + @classmethod + def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Ensure we have an LLM for each step.""" + llm = values.get("llm") + ideation_llm = values.get("ideation_llm") + critique_llm = values.get("critique_llm") + resolver_llm = values.get("resolver_llm") + + if not llm and not ideation_llm: + raise ValueError( + "Either ideation_llm or llm needs to be given. Pass llm " + "if you want to use the same llm for all steps, or pass " + "ideation_llm, critique_llm and resolver_llm if you want " + "to use different llms for each step." + ) + if not llm and not critique_llm: + raise ValueError( + "Either critique_llm or llm needs to be given. Pass llm " + "if you want to use the same llm for all steps, or pass " + "ideation_llm, critique_llm and resolver_llm if you want " + "to use different llms for each step." + ) + if not llm and not resolver_llm: + raise ValueError( + "Either resolver_llm or llm needs to be given. Pass llm " + "if you want to use the same llm for all steps, or pass " + "ideation_llm, critique_llm and resolver_llm if you want " + "to use different llms for each step." + ) + if llm and ideation_llm and critique_llm and resolver_llm: + raise ValueError( + "LLMs are given for each step (ideation_llm, critique_llm," + " resolver_llm), but backup LLM (llm) is also given, which" + " would not be used."
) + return values + + @property + def input_keys(self) -> List[str]: + """Defines the input keys.""" + return self.prompt.input_variables + + @property + def output_keys(self) -> List[str]: + """Defines the output keys.""" + if self.return_intermediate_steps: + return ["ideas", "critique", "resolution"] + return ["resolution"] + + def prep_prompts( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Tuple[PromptValue, Optional[List[str]]]: + """Prepare prompts from inputs.""" + stop = None + if "stop" in inputs: + stop = inputs["stop"] + selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} + prompt = self.prompt.format_prompt(**selected_inputs) + _colored_text = get_colored_text(prompt.to_string(), "green") + _text = "Prompt after formatting:\n" + _colored_text + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) + if "stop" in inputs and inputs["stop"] != stop: + raise ValueError( + "If `stop` is present in any inputs, it should be present in all." + ) + return prompt, stop + + def _call( + self, + input_list: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + prompt, stop = self.prep_prompts(input_list, run_manager=run_manager) + self.history.question = prompt.to_string() + ideas = self._ideate(stop, run_manager) + self.history.ideas = ideas + critique = self._critique(stop, run_manager) + self.history.critique = critique + resolution = self._resolve(stop, run_manager) + if self.return_intermediate_steps: + return {"ideas": ideas, "critique": critique, "resolution": resolution} + return {"resolution": resolution} + + def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str: + """Between steps, only the LLM result text is passed, not the LLMResult object. + This function extracts the text from an LLMResult.""" + if len(result.generations) != 1: + raise ValueError( + f"In SmartLLM the LLM result in step {step} does not contain " + "exactly 1 element. This should never happen." + ) + if len(result.generations[0]) != 1: + raise ValueError( + f"In SmartLLM the LLM in step {step} returned more than " + "1 output. SmartLLM only works with LLMs returning " + "exactly 1 output." + ) + return result.generations[0][0].text + + def get_prompt_strings( + self, stage: str + ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]: + role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = [] + role_strings.append( + ( + HumanMessagePromptTemplate, + "Question: {question}\nAnswer: Let's work this out in a step by " + "step way to be sure we have the right answer:", + ) + ) + if stage == "ideation": + return role_strings + role_strings.extend( + [ + *[ + ( + AIMessagePromptTemplate, + "Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}", + ) + for i in range(self.n_ideas) + ], + ( + HumanMessagePromptTemplate, + "You are a researcher tasked with investigating the " + f"{self.n_ideas} response options provided. List the flaws and " + "faulty logic of each answer option. Let's work this out in a step" + " by step way to be sure we have all the errors:", + ), + ] + ) + if stage == "critique": + return role_strings + role_strings.extend( + [ + (AIMessagePromptTemplate, "Critique: {critique}"), + ( + HumanMessagePromptTemplate, + "You are a resolver tasked with 1) finding which of " + f"the {self.n_ideas} answer options the researcher thought was " + "best, 2) improving that answer and 3) printing the answer in full. 
" + "Don't output anything for step 1 or 2, only the full answer in 3. " + "Let's work this out in a step by step way to be sure we have " + "the right answer:", + ), + ] + ) + if stage == "resolve": + return role_strings + raise ValueError( + "stage should be either 'ideation', 'critique' or 'resolve'," + f" but it is '{stage}'. This should never happen." + ) + + def ideation_prompt(self) -> ChatPromptTemplate: + return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation")) + + def critique_prompt(self) -> ChatPromptTemplate: + return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique")) + + def resolve_prompt(self) -> ChatPromptTemplate: + return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve")) + + def _ideate( + self, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> List[str]: + """Generate n_ideas ideas as response to user prompt.""" + llm = self.ideation_llm if self.ideation_llm else self.llm + prompt = self.ideation_prompt().format_prompt( + **self.history.ideation_prompt_inputs() + ) + callbacks = run_manager.get_child() if run_manager else None + if llm: + ideas = [ + self._get_text_from_llm_result( + llm.generate_prompt([prompt], stop, callbacks), + step="ideate", + ) + for _ in range(self.n_ideas) + ] + for i, idea in enumerate(ideas): + _colored_text = get_colored_text(idea, "blue") + _text = f"Idea {i+1}:\n" + _colored_text + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) + return ideas + else: + raise ValueError("llm is none, which should never happen") + + def _critique( + self, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> str: + """Critique each of the ideas from ideation stage & select best one.""" + llm = self.critique_llm if self.critique_llm else self.llm + prompt = self.critique_prompt().format_prompt( + **self.history.critique_prompt_inputs() + ) + callbacks = run_manager.handlers if run_manager else None + if llm: + critique = self._get_text_from_llm_result( + llm.generate_prompt([prompt], stop, callbacks), step="critique" + ) + _colored_text = get_colored_text(critique, "yellow") + _text = "Critique:\n" + _colored_text + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) + return critique + else: + raise ValueError("llm is none, which should never happen") + + def _resolve( + self, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> str: + """Improve upon the best idea as chosen in critique step & return it.""" + llm = self.resolver_llm if self.resolver_llm else self.llm + prompt = self.resolve_prompt().format_prompt( + **self.history.resolve_prompt_inputs() + ) + callbacks = run_manager.handlers if run_manager else None + if llm: + resolution = self._get_text_from_llm_result( + llm.generate_prompt([prompt], stop, callbacks), step="resolve" + ) + _colored_text = get_colored_text(resolution, "green") + _text = "Resolution:\n" + _colored_text + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) + return resolution + else: + raise ValueError("llm is none, which should never happen") diff --git a/libs/experimental/tests/unit_tests/test_smartllm.py b/libs/experimental/tests/unit_tests/test_smartllm.py new file mode 100644 index 00000000000..b969bbdb5db --- /dev/null +++ b/libs/experimental/tests/unit_tests/test_smartllm.py @@ -0,0 +1,120 @@ +"""Test SmartLLM.""" +from 
langchain.chat_models import FakeListChatModel +from langchain.llms import FakeListLLM +from langchain.prompts.prompt import PromptTemplate + +from langchain_experimental.smart_llm import SmartLLMChain + + +def test_ideation() -> None: + # test that correct responses are returned + responses = ["Idea 1", "Idea 2", "Idea 3"] + llm = FakeListLLM(responses=responses) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = SmartLLMChain(llm=llm, prompt=prompt) + prompt_value, _ = chain.prep_prompts({"product": "socks"}) + chain.history.question = prompt_value.to_string() + results = chain._ideate() + assert results == responses + + # test that correct number of responses are returned + for i in range(1, 5): + responses = [f"Idea {j+1}" for j in range(i)] + llm = FakeListLLM(responses=responses) + chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=i) + prompt_value, _ = chain.prep_prompts({"product": "socks"}) + chain.history.question = prompt_value.to_string() + results = chain._ideate() + assert len(results) == i + + +def test_critique() -> None: + response = "Test Critique" + llm = FakeListLLM(responses=[response]) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2) + prompt_value, _ = chain.prep_prompts({"product": "socks"}) + chain.history.question = prompt_value.to_string() + chain.history.ideas = ["Test Idea 1", "Test Idea 2"] + result = chain._critique() + assert result == response + + +def test_resolver() -> None: + response = "Test resolution" + llm = FakeListLLM(responses=[response]) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2) + prompt_value, _ = chain.prep_prompts({"product": "socks"}) + chain.history.question = prompt_value.to_string() + chain.history.ideas = ["Test Idea 1", "Test Idea 2"] + chain.history.critique = "Test Critique" + result = chain._resolve() + assert result == response + + +def test_all_steps() -> None: + joke = "Why did the chicken cross the Mobius strip?" + response = "Resolution response" + ideation_llm = FakeListLLM(responses=["Ideation response" for _ in range(20)]) + critique_llm = FakeListLLM(responses=["Critique response" for _ in range(20)]) + resolver_llm = FakeListLLM(responses=[response for _ in range(20)]) + prompt = PromptTemplate( + input_variables=["joke"], + template="Explain this joke to me: {joke}?", + ) + chain = SmartLLMChain( + ideation_llm=ideation_llm, + critique_llm=critique_llm, + resolver_llm=resolver_llm, + prompt=prompt, + ) + result = chain(joke) + assert result["joke"] == joke + assert result["resolution"] == response + + +def test_intermediate_output() -> None: + joke = "Why did the chicken cross the Mobius strip?" 
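+    # With the default n_ideas=3, the chain makes 3 ideation calls, then 1 critique
+    # call and 1 resolution call, so the 5 canned responses below are consumed in
+    # exactly that order.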
+ llm = FakeListLLM(responses=[f"Response {i+1}" for i in range(5)]) + prompt = PromptTemplate( + input_variables=["joke"], + template="Explain this joke to me: {joke}?", + ) + chain = SmartLLMChain(llm=llm, prompt=prompt, return_intermediate_steps=True) + result = chain(joke) + assert result["joke"] == joke + assert result["ideas"] == [f"Response {i+1}" for i in range(3)] + assert result["critique"] == "Response 4" + assert result["resolution"] == "Response 5" + + +def test_all_steps_with_chat_model() -> None: + joke = "Why did the chicken cross the Mobius strip?" + response = "Resolution response" + + ideation_llm = FakeListChatModel(responses=["Ideation response" for _ in range(20)]) + critique_llm = FakeListChatModel(responses=["Critique response" for _ in range(20)]) + resolver_llm = FakeListChatModel(responses=[response for _ in range(20)]) + prompt = PromptTemplate( + input_variables=["joke"], + template="Explain this joke to me: {joke}?", + ) + chain = SmartLLMChain( + ideation_llm=ideation_llm, + critique_llm=critique_llm, + resolver_llm=resolver_llm, + prompt=prompt, + ) + result = chain(joke) + assert result["joke"] == joke + assert result["resolution"] == response
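+
+
+# A hedged sketch, not part of the original test suite: the root validator should
+# reject a configuration in which no LLM is given for any step. Assumes pytest is
+# available and that pydantic surfaces the underlying ValueError.
+def test_missing_llm_raises_error() -> None:
+    import pytest
+
+    prompt = PromptTemplate(
+        input_variables=["product"],
+        template="What is a good name for a company that makes {product}?",
+    )
+    with pytest.raises(ValueError):
+        SmartLLMChain(prompt=prompt)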