From 8af35341709a5b047f8773c99dd25caf3860d421 Mon Sep 17 00:00:00 2001
From: Ankush Gola
Date: Sat, 10 Jun 2023 15:40:31 -0700
Subject: [PATCH] avoid double counting for openai callback

---
 langchain/callbacks/openai_info.py | 58 ------------------------------
 langchain/llms/base.py             | 46 ++----------------------
 langchain/schema.py                | 26 ++++++++++++++
 3 files changed, 28 insertions(+), 102 deletions(-)

diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py
index e8d2f3b6463..bbe238686d8 100644
--- a/langchain/callbacks/openai_info.py
+++ b/langchain/callbacks/openai_info.py
@@ -110,64 +110,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens
 
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        """Print out the log in specified color."""
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Run on agent action."""
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Run on agent end."""
-        pass
-
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
         return self
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 83b02e7251d..884c5c9c73b 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -108,44 +108,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose
 
-    def _flatten_llm_result(
-        self, prompts: List[str], result: LLMResult
-    ) -> List[LLMResult]:
-        """Flatten the LLMResult into a list of LLMResults for batched runs."""
-        if len(result.generations) != len(prompts):
-            raise ValueError(
-                f"Expected {len(prompts)} generations, got {len(result.generations)}"
-            )
-
-        llm_outputs = []
-        for prompt, gens in zip(prompts, result.generations):
-            try:
-                token_usage = {
-                    "completion_tokens": self.get_num_tokens(
-                        "".join([gen.text for gen in gens])
-                    ),
-                    "prompt_tokens": self.get_num_tokens(prompt),
-                }
-                token_usage["total_tokens"] = (
-                    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
-                )
-                llm_output = {
-                    "token_usage": token_usage,
-                }
-                if result.llm_output and result.llm_output["model_name"]:
-                    llm_output["model_name"] = result.llm_output["model_name"]
-
-            except ImportError:
-                llm_output = None
-            llm_outputs.append(llm_output)
-        return [
-            LLMResult(
-                generations=[gen],
-                llm_output=llm_output,
-            )
-            for gen, llm_output in zip(result.generations, llm_outputs)
-        ]
-
     @abstractmethod
     def _generate(
         self,
@@ -203,9 +165,7 @@ class BaseLLM(BaseLanguageModel, ABC):
             for run_manager in run_managers:
                 run_manager.on_llm_error(e)
             raise e
-        flattened_outputs = (
-            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
-        )
+        flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
             manager.on_llm_end(flattened_output)
         if run_managers:
@@ -298,9 +258,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 *[run_manager.on_llm_error(e) for run_manager in run_managers]
             )
             raise e
-        flattened_outputs = (
-            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
-        )
+        flattened_outputs = output.flatten()
         await asyncio.gather(
             *[
                 run_manager.on_llm_end(flattened_output)
diff --git a/langchain/schema.py b/langchain/schema.py
index a7240c01292..084806fb30c 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -185,6 +185,32 @@ class LLMResult(BaseModel):
     run: Optional[List[RunInfo]] = None
     """Run metadata."""
 
+    def flatten(self) -> List[LLMResult]:
+        """Flatten generations into a single list."""
+        llm_results = []
+        for i, gen_list in enumerate(self.generations):
+            # Avoid double counting tokens in OpenAICallback
+            if i == 0:
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=self.llm_output,
+                    )
+                )
+            else:
+                if self.llm_output is not None:
+                    llm_output = self.llm_output.copy()
+                    llm_output["token_usage"] = None
+                else:
+                    llm_output = None
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=llm_output,
+                    )
+                )
+        return llm_results
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
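
A minimal sketch of the new behavior, assuming this patch is applied:
LLMResult.flatten() keeps token_usage on the first flattened result and
blanks it on the rest, so OpenAICallbackHandler.on_llm_end counts the usage
only once per batched call. Generation is the existing class from
langchain.schema; the token counts and model name below are made-up sample
values.

    from langchain.schema import Generation, LLMResult

    # One batched call that produced generations for two prompts.
    batched = LLMResult(
        generations=[[Generation(text="foo")], [Generation(text="bar")]],
        llm_output={
            "token_usage": {
                "prompt_tokens": 4,
                "completion_tokens": 2,
                "total_tokens": 6,
            },
            "model_name": "text-davinci-003",  # sample value
        },
    )

    first, second = batched.flatten()
    # The first flattened result keeps the real usage numbers ...
    assert first.llm_output["token_usage"]["total_tokens"] == 6
    # ... while the rest carry none, so the callback cannot count them twice.
    assert second.llm_output["token_usage"] is None
    assert second.llm_output["model_name"] == "text-davinci-003"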