mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
fix caching (#555)
This commit is contained in:
parent
74932f2516
commit
9833fcfe32
@ -109,7 +109,7 @@ class BaseLLM(BaseModel, ABC):
|
|||||||
self.callback_manager.on_llm_end(new_results)
|
self.callback_manager.on_llm_end(new_results)
|
||||||
for i, result in enumerate(new_results.generations):
|
for i, result in enumerate(new_results.generations):
|
||||||
existing_prompts[missing_prompt_idxs[i]] = result
|
existing_prompts[missing_prompt_idxs[i]] = result
|
||||||
prompt = prompts[i]
|
prompt = prompts[missing_prompt_idxs[i]]
|
||||||
langchain.llm_cache.update(prompt, llm_string, result)
|
langchain.llm_cache.update(prompt, llm_string, result)
|
||||||
generations = [existing_prompts[i] for i in range(len(prompts))]
|
generations = [existing_prompts[i] for i in range(len(prompts))]
|
||||||
return LLMResult(generations=generations, llm_output=new_results.llm_output)
|
return LLMResult(generations=generations, llm_output=new_results.llm_output)
|
||||||
|
@ -14,6 +14,9 @@ def test_caching() -> None:
|
|||||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||||
output = llm.generate(["foo", "bar", "foo"])
|
output = llm.generate(["foo", "bar", "foo"])
|
||||||
|
expected_cache_output = [Generation(text="foo")]
|
||||||
|
cache_output = langchain.llm_cache.lookup("bar", llm_string)
|
||||||
|
assert cache_output == expected_cache_output
|
||||||
langchain.llm_cache = None
|
langchain.llm_cache = None
|
||||||
expected_generations = [
|
expected_generations = [
|
||||||
[Generation(text="fizz")],
|
[Generation(text="fizz")],
|
||||||
|
Loading…
Reference in New Issue
Block a user