diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index b8d05ccd284..97c062f3bfc 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -109,7 +109,7 @@ class BaseLLM(BaseModel, ABC):
         self.callback_manager.on_llm_end(new_results)
         for i, result in enumerate(new_results.generations):
             existing_prompts[missing_prompt_idxs[i]] = result
-            prompt = prompts[i]
+            prompt = prompts[missing_prompt_idxs[i]]
             langchain.llm_cache.update(prompt, llm_string, result)
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=new_results.llm_output)
diff --git a/tests/unit_tests/llms/test_base.py b/tests/unit_tests/llms/test_base.py
index da67a9da85e..05f97d6efc6 100644
--- a/tests/unit_tests/llms/test_base.py
+++ b/tests/unit_tests/llms/test_base.py
@@ -14,6 +14,9 @@ def test_caching() -> None:
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
     output = llm.generate(["foo", "bar", "foo"])
+    expected_cache_output = [Generation(text="foo")]
+    cache_output = langchain.llm_cache.lookup("bar", llm_string)
+    assert cache_output == expected_cache_output
     langchain.llm_cache = None
     expected_generations = [
         [Generation(text="fizz")],