From 960d0f6646e02ccba8e1ec388a5b920de2bc6a88 Mon Sep 17 00:00:00 2001
From: Ankush Gola
Date: Wed, 7 Jun 2023 22:41:05 -0700
Subject: [PATCH] recalc tokens

---
 langchain/llms/base.py | 36 ++++++++++++++++++++++++++++++++----
 langchain/schema.py    | 10 ----------
 2 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index b6efa5e8a17..f5e76e310aa 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -108,6 +108,34 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose
 
+    def _flatten_llm_result(self, prompts: List[str], result: LLMResult) -> List[LLMResult]:
+        """Flatten the LLMResult into a list of LLMResults for batched runs."""
+        if len(result.generations) != len(prompts):
+            raise ValueError(
+                f"Expected {len(prompts)} generations, got {len(result.generations)}"
+            )
+
+        llm_outputs = []
+        for prompt, gens in zip(prompts, result.generations):
+            try:
+                llm_output = {
+                    "completion_tokens": self.get_num_tokens("".join([gen.text for gen in gens])),
+                    "prompt_tokens": self.get_num_tokens(prompt),
+                }
+                llm_output["total_tokens"] = (
+                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                )
+            except ImportError:
+                llm_output = None
+            llm_outputs.append(llm_output)
+        return [
+            LLMResult(
+                generations=[gen],
+                llm_output=llm_output,
+            )
+            for gen, llm_output in zip(result.generations, llm_outputs)
+        ]
+
     @abstractmethod
     def _generate(
         self,
@@ -196,7 +224,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 for run_manager in run_managers:
                     run_manager.on_llm_error(e)
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             for manager, flattened_output in zip(run_managers, flattened_outputs):
                 manager.on_llm_end(flattened_output)
             if run_managers:
@@ -224,7 +252,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     for run_manager in run_managers:
                         run_manager.on_llm_error(e)
                     raise e
-                flattened_outputs = new_results.flatten()
+                flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
                 for manager, flattened_output in zip(run_managers, flattened_outputs):
                     manager.on_llm_end(flattened_output)
                 llm_output = update_cache(
@@ -285,7 +313,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     *[run_manager.on_llm_error(e) for run_manager in run_managers]
                 )
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             await asyncio.gather(
                 *[
                     run_manager.on_llm_end(flattened_output)
@@ -318,7 +346,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                         *[run_manager.on_llm_error(e) for run_manager in run_managers]
                     )
                     raise e
-                flattened_outputs = new_results.flatten()
+                flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
                 await asyncio.gather(
                     *[
                         run_manager.on_llm_end(flattened_output)
diff --git a/langchain/schema.py b/langchain/schema.py
index 005b96e6b8e..a7240c01292 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -185,16 +185,6 @@ class LLMResult(BaseModel):
     run: Optional[List[RunInfo]] = None
     """Run metadata."""
 
-    def flatten(self) -> List[LLMResult]:
-        """Flatten generations into a single list."""
-        return [
-            LLMResult(
-                generations=[gen],
-                llm_output=self.llm_output,
-            )
-            for gen in self.generations
-        ]
-
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
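
Note (not part of the patch): a minimal sketch of what the new
_flatten_llm_result does once this change is applied. It splits one batched
LLMResult into one result per prompt and recomputes token usage for each,
instead of copying the batch-level llm_output to every split result as the
removed LLMResult.flatten() did. The FakeLLM class, prompts, and token counts
below are illustrative assumptions for the sketch, not code from this change;
get_num_tokens is overridden with a trivial word count so the numbers are easy
to check.

from typing import Any, List, Optional

from langchain.llms.base import LLM
from langchain.schema import Generation, LLMResult


class FakeLLM(LLM):
    """Stand-in LLM for this sketch; counts tokens as whitespace words."""

    @property
    def _llm_type(self) -> str:
        return "fake"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        return "ok"

    def get_num_tokens(self, text: str) -> int:
        # Deliberately simple tokenization so the assertions below are obvious.
        return len(text.split())


llm = FakeLLM()

# One batched result covering two prompts; its llm_output is batch-level,
# which is exactly what the patch recomputes per prompt.
batched = LLMResult(
    generations=[[Generation(text="hello there")], [Generation(text="bye")]],
    llm_output={"token_usage": {}},
)

flat = llm._flatten_llm_result(["first prompt", "a second longer prompt"], batched)

assert len(flat) == 2
assert flat[0].llm_output == {
    "completion_tokens": 2,  # "hello there"
    "prompt_tokens": 2,      # "first prompt"
    "total_tokens": 4,
}
assert flat[1].llm_output["prompt_tokens"] == 4  # "a second longer prompt"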