avoid double counting for openai callback

Ankush Gola
2023-06-10 15:40:31 -07:00
parent 1724366b12
commit 8af3534170
3 changed files with 28 additions and 102 deletions
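In a batched `generate` call, each prompt gets its own run manager and its own `on_llm_end` callback, so the batch-level `LLMResult` must be split into per-prompt results first. If every split result carried the batch's `token_usage`, a totals-accumulating handler like `OpenAICallbackHandler` would add the same usage once per prompt; the new `LLMResult.flatten()` added below attaches it only to the first result. A standalone sketch of that accounting, using plain dicts as stand-ins for `LLMResult` (the helpers `flatten_naive`, `flatten_fixed`, and `accumulate` are illustrative, not from this commit):

from typing import List

# A two-prompt batch whose provider reported 42 total tokens for the whole batch.
batch = {
    "generations": [["generation for prompt 0"], ["generation for prompt 1"]],
    "llm_output": {"token_usage": {"total_tokens": 42}, "model_name": "text-davinci-003"},
}

def flatten_naive(result: dict) -> List[dict]:
    # Every per-prompt result keeps the full batch llm_output.
    return [
        {"generations": [gens], "llm_output": result["llm_output"]}
        for gens in result["generations"]
    ]

def flatten_fixed(result: dict) -> List[dict]:
    # Mirrors the new LLMResult.flatten(): only the first result keeps token_usage.
    flattened = []
    for i, gens in enumerate(result["generations"]):
        llm_output = (
            result["llm_output"] if i == 0
            else dict(result["llm_output"], token_usage=None)
        )
        flattened.append({"generations": [gens], "llm_output": llm_output})
    return flattened

def accumulate(flattened: List[dict]) -> int:
    # What a totals-accumulating handler sees: one on_llm_end per prompt.
    total = 0
    for result in flattened:
        usage = (result["llm_output"] or {}).get("token_usage") or {}
        total += usage.get("total_tokens", 0)
    return total

print(accumulate(flatten_naive(batch)))  # 84: the batch counted once per prompt
print(accumulate(flatten_fixed(batch)))  # 42: the batch counted exactly once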


@@ -110,64 +110,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        """Print out the log in specified color."""
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Run on agent action."""
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Run on agent end."""
-        pass
 
     def __copy__(self) -> "OpenAICallbackHandler":
        """Return a copy of the callback handler."""
        return self
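Worth noting: `__copy__` (unchanged context above) returns `self`, so copying the handler hands back the same object and all "copies" keep feeding one running total. A quick identity check, assuming `OpenAICallbackHandler` is importable from `langchain.callbacks`:

import copy
from langchain.callbacks import OpenAICallbackHandler

handler = OpenAICallbackHandler()
# __copy__ returns self, so copies share the same token counters
assert copy.copy(handler) is handler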


@@ -108,44 +108,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose
 
-    def _flatten_llm_result(
-        self, prompts: List[str], result: LLMResult
-    ) -> List[LLMResult]:
-        """Flatten the LLMResult into a list of LLMResults for batched runs."""
-        if len(result.generations) != len(prompts):
-            raise ValueError(
-                f"Expected {len(prompts)} generations, got {len(result.generations)}"
-            )
-        llm_outputs = []
-        for prompt, gens in zip(prompts, result.generations):
-            try:
-                token_usage = {
-                    "completion_tokens": self.get_num_tokens(
-                        "".join([gen.text for gen in gens])
-                    ),
-                    "prompt_tokens": self.get_num_tokens(prompt),
-                }
-                token_usage["total_tokens"] = (
-                    token_usage["completion_tokens"] + token_usage["prompt_tokens"]
-                )
-                llm_output = {
-                    "token_usage": token_usage,
-                }
-                if result.llm_output and result.llm_output["model_name"]:
-                    llm_output["model_name"] = result.llm_output["model_name"]
-            except ImportError:
-                llm_output = None
-            llm_outputs.append(llm_output)
-        return [
-            LLMResult(
-                generations=[gen],
-                llm_output=llm_output,
-            )
-            for gen, llm_output in zip(result.generations, llm_outputs)
-        ]
-
     @abstractmethod
     def _generate(
         self,
@@ -203,9 +165,7 @@ class BaseLLM(BaseLanguageModel, ABC):
             for run_manager in run_managers:
                 run_manager.on_llm_error(e)
             raise e
-        flattened_outputs = (
-            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
-        )
+        flattened_outputs = output.flatten()
         for manager, flattened_output in zip(run_managers, flattened_outputs):
             manager.on_llm_end(flattened_output)
         if run_managers:
@@ -298,9 +258,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 *[run_manager.on_llm_error(e) for run_manager in run_managers]
             )
             raise e
-        flattened_outputs = (
-            self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
-        )
+        flattened_outputs = output.flatten()
         await asyncio.gather(
             *[
                 run_manager.on_llm_end(flattened_output)
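With both the sync and async `generate` paths switched to `output.flatten()`, each per-prompt run manager receives one flattened result and batch usage lands in the totals exactly once. A usage sketch, assuming an OpenAI API key is configured and this era's `get_openai_callback` context manager:

from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(model_name="text-davinci-003")
with get_openai_callback() as cb:
    # Two prompts in one batched call; usage should be counted once.
    llm.generate(["tell me a joke", "tell me a poem"])
print(cb.total_tokens)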


@@ -185,6 +185,32 @@ class LLMResult(BaseModel):
     run: Optional[List[RunInfo]] = None
     """Run metadata."""
 
+    def flatten(self) -> List[LLMResult]:
+        """Flatten generations into a single list."""
+        llm_results = []
+        for i, gen_list in enumerate(self.generations):
+            # Avoid double counting tokens in OpenAICallback
+            if i == 0:
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=self.llm_output,
+                    )
+                )
+            else:
+                if self.llm_output is not None:
+                    llm_output = self.llm_output.copy()
+                    llm_output["token_usage"] = None
+                else:
+                    llm_output = None
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=llm_output,
+                    )
+                )
+        return llm_results
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
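Flattening a two-prompt result with the method above keeps `token_usage` on the first flattened result and nulls it on the rest, while other `llm_output` keys such as `model_name` survive the copy. A small check, assuming `Generation` and `LLMResult` from `langchain.schema` at this revision:

from langchain.schema import Generation, LLMResult

result = LLMResult(
    generations=[[Generation(text="output 0")], [Generation(text="output 1")]],
    llm_output={"token_usage": {"total_tokens": 42}, "model_name": "text-davinci-003"},
)
first, second = result.flatten()
print(first.llm_output)   # {'token_usage': {'total_tokens': 42}, 'model_name': 'text-davinci-003'}
print(second.llm_output)  # {'token_usage': None, 'model_name': 'text-davinci-003'}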