Harrison/fix caching bug (#788)

Co-authored-by: thepok <richterthepok@yahoo.de>
commit 5f73d06502 (parent 248c297f1b)
Author: Harrison Chase, committed by GitHub
Date: 2023-01-28 14:24:30 -08:00

@@ -92,21 +92,25 @@ class BaseLLM(BaseModel, ABC):
             else:
                 missing_prompts.append(prompt)
                 missing_prompt_idxs.append(i)
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
-        )
-        try:
-            new_results = self._generate(missing_prompts, stop=stop)
-        except (KeyboardInterrupt, Exception) as e:
-            self.callback_manager.on_llm_error(e, verbose=self.verbose)
-            raise e
-        self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
-        for i, result in enumerate(new_results.generations):
-            existing_prompts[missing_prompt_idxs[i]] = result
-            prompt = prompts[missing_prompt_idxs[i]]
-            langchain.llm_cache.update(prompt, llm_string, result)
+        if len(missing_prompts) > 0:
+            self.callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
+            )
+            try:
+                new_results = self._generate(missing_prompts, stop=stop)
+            except (KeyboardInterrupt, Exception) as e:
+                self.callback_manager.on_llm_error(e, verbose=self.verbose)
+                raise e
+            self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
+            for i, result in enumerate(new_results.generations):
+                existing_prompts[missing_prompt_idxs[i]] = result
+                prompt = prompts[missing_prompt_idxs[i]]
+                langchain.llm_cache.update(prompt, llm_string, result)
+            llm_output = new_results.llm_output
+        else:
+            llm_output = {}
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=new_results.llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output)
 
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""