recalc tokens

Ankush Gola
2023-06-07 22:41:05 -07:00
parent 0197abf2b3
commit 960d0f6646
2 changed files with 32 additions and 14 deletions


@@ -108,6 +108,34 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose
 
+    def _flatten_llm_result(self, prompts: List[str], result: LLMResult) -> List[LLMResult]:
+        """Flatten the LLMResult into a list of LLMResults for batched runs."""
+        if len(result.generations) != len(prompts):
+            raise ValueError(
+                f"Expected {len(prompts)} generations, got {len(result.generations)}"
+            )
+        llm_outputs = []
+        for prompt, gens in zip(prompts, result.generations):
+            try:
+                llm_output = {
+                    "completion_tokens": self.get_num_tokens("".join([gen.text for gen in gens])),
+                    "prompt_tokens": self.get_num_tokens(prompt),
+                }
+                llm_output["total_tokens"] = (
+                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                )
+            except ImportError:
+                llm_output = None
+            llm_outputs.append(llm_output)
+        return [
+            LLMResult(
+                generations=[gen],
+                llm_output=llm_output,
+            )
+            for gen, llm_output in zip(result.generations, llm_outputs)
+        ]
+
     @abstractmethod
     def _generate(
         self,
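
For orientation, a minimal standalone sketch of what the new _flatten_llm_result does: split one batched LLMResult into one result per prompt, recomputing token usage for each. The Generation/LLMResult stand-ins and the whitespace tokenizer below are simplified assumptions for illustration only; in the diff, BaseLLM.get_num_tokens does the actual counting.

# Standalone sketch; simplified stand-ins for the real langchain classes.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Generation:
    text: str

@dataclass
class LLMResult:
    generations: List[List[Generation]]  # one inner list per prompt
    llm_output: Optional[dict] = None

def get_num_tokens(text: str) -> int:
    # Toy tokenizer (whitespace split); BaseLLM uses the model's tokenizer.
    return len(text.split())

def flatten_llm_result(prompts: List[str], result: LLMResult) -> List[LLMResult]:
    # Mirrors the added method: one LLMResult per prompt, with recomputed usage.
    if len(result.generations) != len(prompts):
        raise ValueError(
            f"Expected {len(prompts)} generations, got {len(result.generations)}"
        )
    flattened = []
    for prompt, gens in zip(prompts, result.generations):
        completion_tokens = get_num_tokens("".join(gen.text for gen in gens))
        prompt_tokens = get_num_tokens(prompt)
        llm_output = {
            "completion_tokens": completion_tokens,
            "prompt_tokens": prompt_tokens,
            "total_tokens": completion_tokens + prompt_tokens,
        }
        flattened.append(LLMResult(generations=[gens], llm_output=llm_output))
    return flattened

batched = LLMResult(generations=[[Generation("Paris")], [Generation("Berlin")]])
for r in flatten_llm_result(["Capital of France?", "Capital of Germany?"], batched):
    print(r.llm_output)
# {'completion_tokens': 1, 'prompt_tokens': 3, 'total_tokens': 4}
# {'completion_tokens': 1, 'prompt_tokens': 3, 'total_tokens': 4}
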
@@ -196,7 +224,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 for run_manager in run_managers:
                     run_manager.on_llm_error(e)
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             for manager, flattened_output in zip(run_managers, flattened_outputs):
                 manager.on_llm_end(flattened_output)
             if run_managers:
if run_managers:
@@ -224,7 +252,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 for run_manager in run_managers:
                     run_manager.on_llm_error(e)
                 raise e
-            flattened_outputs = new_results.flatten()
+            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
             for manager, flattened_output in zip(run_managers, flattened_outputs):
                 manager.on_llm_end(flattened_output)
             llm_output = update_cache(
llm_output = update_cache(
@@ -285,7 +313,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     *[run_manager.on_llm_error(e) for run_manager in run_managers]
                 )
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             await asyncio.gather(
                 *[
                     run_manager.on_llm_end(flattened_output)
@@ -318,7 +346,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     *[run_manager.on_llm_error(e) for run_manager in run_managers]
                 )
                 raise e
-            flattened_outputs = new_results.flatten()
+            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
             await asyncio.gather(
                 *[
                     run_manager.on_llm_end(flattened_output)
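
With the four call sites above switched from output.flatten() / new_results.flatten() to self._flatten_llm_result(...), each run manager's on_llm_end now receives a single-prompt LLMResult whose llm_output carries that prompt's recomputed token counts (or None if token counting raised ImportError). A hypothetical handler, sketched here with illustrative names rather than the library's callback base class, could surface the per-run usage like this:

# Hypothetical callback sketch (illustrative, not the library API): each
# on_llm_end call now sees per-prompt usage in response.llm_output.
class TokenUsageLogger:
    def on_llm_end(self, response, **kwargs):
        usage = response.llm_output or {}
        print(
            "prompt_tokens=%s completion_tokens=%s total_tokens=%s"
            % (
                usage.get("prompt_tokens"),
                usage.get("completion_tokens"),
                usage.get("total_tokens"),
            )
        )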


@@ -185,16 +185,6 @@ class LLMResult(BaseModel):
     run: Optional[List[RunInfo]] = None
     """Run metadata."""
 
-    def flatten(self) -> List[LLMResult]:
-        """Flatten generations into a single list."""
-        return [
-            LLMResult(
-                generations=[gen],
-                llm_output=self.llm_output,
-            )
-            for gen in self.generations
-        ]
-
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
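
Note: the removed LLMResult.flatten() stamped the batch-level llm_output onto every per-prompt result. After this commit, flattening lives in BaseLLM._flatten_llm_result instead, which recomputes prompt/completion/total token counts per prompt rather than duplicating the whole batch's usage across each run.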