mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
avoid double counting for openai callback
This commit is contained in:
@@ -110,64 +110,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
|
||||
def on_llm_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
    """Do nothing when an LLM call raises.

    Args:
        error: The exception that was raised; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Do nothing when a chain starts.

    Args:
        serialized: Serialized description of the chain; ignored.
        inputs: Inputs passed to the chain; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Do nothing when a chain finishes.

    Args:
        outputs: Outputs produced by the chain; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_chain_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
    """Do nothing when a chain raises.

    Args:
        error: The exception that was raised; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    **kwargs: Any,
) -> None:
    """Do nothing when a tool starts.

    Args:
        serialized: Serialized description of the tool; ignored.
        input_str: Input passed to the tool; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_tool_end(
    self,
    output: str,
    color: Optional[str] = None,
    observation_prefix: Optional[str] = None,
    llm_prefix: Optional[str] = None,
    **kwargs: Any,
) -> None:
    """Do nothing when a tool finishes.

    Args:
        output: Output returned by the tool; ignored.
        color: Ignored.
        observation_prefix: Ignored.
        llm_prefix: Ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_tool_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
    """Do nothing when a tool raises.

    Args:
        error: The exception that was raised; ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Do nothing when an agent takes an action.

    Args:
        action: The action the agent decided on; ignored.
        **kwargs: Extra callback arguments; ignored.

    Returns:
        Always ``None``.
    """
|
||||
|
||||
def on_agent_finish(
    self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
    """Do nothing when an agent finishes.

    Args:
        finish: The agent's final result; ignored.
        color: Ignored.
        **kwargs: Extra callback arguments; ignored.
    """
|
||||
|
||||
def __copy__(self) -> "OpenAICallbackHandler":
    """Return this very instance instead of a new object.

    ``copy.copy(handler)`` dispatches to this hook, so a shallow "copy" of
    the handler is the handler itself and shares all of its state.

    Returns:
        ``self``, unchanged.
    """
    return self
|
||||
|
||||
@@ -108,44 +108,6 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
else:
|
||||
return verbose
|
||||
|
||||
def _flatten_llm_result(
    self, prompts: List[str], result: LLMResult
) -> List[LLMResult]:
    """Split a batched ``LLMResult`` into one ``LLMResult`` per prompt.

    Token usage is re-estimated per prompt with ``self.get_num_tokens``
    rather than taken from the batched ``llm_output``.

    Args:
        prompts: The prompts of the batch; paired positionally with
            ``result.generations``.
        result: The batched result to split.

    Returns:
        One ``LLMResult`` per prompt, each wrapping that prompt's list of
        generations. Its ``llm_output`` holds the estimated ``token_usage``
        (plus ``model_name`` when the batched result carries a truthy one),
        or is ``None`` when token counting raised ``ImportError``.

    Raises:
        ValueError: If the number of generation lists differs from the
            number of prompts.
    """
    if len(result.generations) != len(prompts):
        raise ValueError(
            f"Expected {len(prompts)} generations, got {len(result.generations)}"
        )

    # Looked up once, and with .get(): an llm_output that has no
    # "model_name" key must not abort the whole flattening with a KeyError.
    model_name = (result.llm_output or {}).get("model_name")

    llm_outputs = []
    for prompt, gens in zip(prompts, result.generations):
        # Keep the try body down to the calls that can raise ImportError.
        try:
            completion_tokens = self.get_num_tokens(
                "".join(gen.text for gen in gens)
            )
            prompt_tokens = self.get_num_tokens(prompt)
        except ImportError:
            llm_outputs.append(None)
            continue

        llm_output = {
            "token_usage": {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": completion_tokens + prompt_tokens,
            },
        }
        if model_name:
            llm_output["model_name"] = model_name
        llm_outputs.append(llm_output)

    return [
        LLMResult(generations=[gens], llm_output=llm_output)
        for gens, llm_output in zip(result.generations, llm_outputs)
    ]
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
self,
|
||||
@@ -203,9 +165,7 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_llm_error(e)
|
||||
raise e
|
||||
flattened_outputs = (
|
||||
self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
|
||||
)
|
||||
flattened_outputs = output.flatten()
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
manager.on_llm_end(flattened_output)
|
||||
if run_managers:
|
||||
@@ -298,9 +258,7 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
*[run_manager.on_llm_error(e) for run_manager in run_managers]
|
||||
)
|
||||
raise e
|
||||
flattened_outputs = (
|
||||
self._flatten_llm_result(prompts, output) if len(prompts) > 1 else [output]
|
||||
)
|
||||
flattened_outputs = output.flatten()
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
|
||||
@@ -185,6 +185,32 @@ class LLMResult(BaseModel):
|
||||
run: Optional[List[RunInfo]] = None
|
||||
"""Run metadata."""
|
||||
|
||||
def flatten(self) -> List["LLMResult"]:
    """Split this result into one ``LLMResult`` per entry of ``generations``.

    Only the first returned result keeps the original ``llm_output``. Every
    later one gets a shallow copy whose ``token_usage`` is emptied, so that
    token usage is not counted more than once (e.g. by the OpenAI callback).

    Returns:
        A list with one ``LLMResult`` per generation list, in order; empty
        when ``generations`` is empty. ``llm_output`` is ``None`` on every
        item when this result's ``llm_output`` is ``None``.
    """
    llm_results = []
    for i, gen_list in enumerate(self.generations):
        if i == 0 or self.llm_output is None:
            # First result carries the usage as-is; None simply propagates.
            llm_output = self.llm_output
        else:
            # Copy before editing so the original llm_output (shared with
            # the first result) keeps its token usage.
            llm_output = self.llm_output.copy()
            # An empty dict instead of None: `"key" in None` raises
            # TypeError, whereas an empty mapping safely reports no usage.
            # NOTE(review): assumes consumers treat {} as "no usage" - confirm.
            llm_output["token_usage"] = {}
        llm_results.append(
            LLMResult(
                generations=[gen_list],
                llm_output=llm_output,
            )
        )
    return llm_results
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
|
||||
Reference in New Issue
Block a user