mirror of https://github.com/hwchase17/langchain.git

Commit: recalc tokens

@@ -108,6 +108,34 @@ class BaseLLM(BaseLanguageModel, ABC):
         else:
             return verbose

+    def _flatten_llm_result(self, prompts: List[str], result: LLMResult) -> List[LLMResult]:
+        """Flatten the LLMResult into a list of LLMResults for batched runs."""
+        if len(result.generations) != len(prompts):
+            raise ValueError(
+                f"Expected {len(prompts)} generations, got {len(result.generations)}"
+            )
+
+        llm_outputs = []
+        for prompt, gens in zip(prompts, result.generations):
+            try:
+                llm_output = {
+                    "completion_tokens": self.get_num_tokens("".join([gen.text for gen in gens])),
+                    "prompt_tokens": self.get_num_tokens(prompt),
+                }
+                llm_output["total_tokens"] = (
+                    llm_output["completion_tokens"] + llm_output["prompt_tokens"]
+                )
+            except ImportError:
+                llm_output = None
+            llm_outputs.append(llm_output)
+        return [
+            LLMResult(
+                generations=[gen],
+                llm_output=llm_output,
+            )
+            for gen, llm_output in zip(result.generations, llm_outputs)
+        ]
+
     @abstractmethod
     def _generate(
         self,
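
In effect, the new helper turns one batched LLMResult into one result per prompt, recomputing token usage per prompt with get_num_tokens instead of copying the batch-level llm_output, and falling back to None if token counting raises ImportError (e.g. a missing tokenizer dependency). A stand-alone sketch of that behavior, using a toy whitespace "tokenizer" in place of the model's real get_num_tokens:

    # Stand-alone sketch, not the langchain API: a toy whitespace "tokenizer"
    # stands in for BaseLLM.get_num_tokens.
    from typing import Dict, List


    def get_num_tokens(text: str) -> int:
        return len(text.split())


    def flatten(prompts: List[str], generations: List[List[str]]) -> List[Dict[str, int]]:
        per_prompt = []
        for prompt, gens in zip(prompts, generations):
            completion_tokens = get_num_tokens("".join(gens))
            prompt_tokens = get_num_tokens(prompt)
            per_prompt.append(
                {
                    "completion_tokens": completion_tokens,
                    "prompt_tokens": prompt_tokens,
                    "total_tokens": completion_tokens + prompt_tokens,
                }
            )
        return per_prompt


    print(flatten(["hi there", "tell me a joke"], [["hello!"], ["why did the chicken"]]))
    # [{'completion_tokens': 1, 'prompt_tokens': 2, 'total_tokens': 3},
    #  {'completion_tokens': 4, 'prompt_tokens': 4, 'total_tokens': 8}]
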
@@ -196,7 +224,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                 for run_manager in run_managers:
                     run_manager.on_llm_error(e)
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             for manager, flattened_output in zip(run_managers, flattened_outputs):
                 manager.on_llm_end(flattened_output)
             if run_managers:
@@ -224,7 +252,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     for run_manager in run_managers:
                         run_manager.on_llm_error(e)
                     raise e
-                flattened_outputs = new_results.flatten()
+                flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
                 for manager, flattened_output in zip(run_managers, flattened_outputs):
                     manager.on_llm_end(flattened_output)
                 llm_output = update_cache(
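
In both synchronous generate paths (the uncached branch at -196 and the cached branch at -224 above), the per-prompt flattening now happens inside BaseLLM before results are handed to the run managers, so each on_llm_end callback sees that prompt's own token counts rather than the shared batch-level llm_output. A hedged sketch of a handler that logs them; the handler name is illustrative, and it assumes langchain's BaseCallbackHandler and LLMResult as of this commit:

    # Illustrative callback handler (assumes langchain.callbacks.base and
    # langchain.schema as of this commit). Each on_llm_end call now receives
    # a single-prompt LLMResult whose llm_output was built by
    # _flatten_llm_result, or None if token counting failed.
    from typing import Any

    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.schema import LLMResult


    class TokenUsageLogger(BaseCallbackHandler):
        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            usage = response.llm_output
            if usage is not None:
                print(
                    f"prompt={usage['prompt_tokens']} "
                    f"completion={usage['completion_tokens']} "
                    f"total={usage['total_tokens']}"
                )

Such a handler could be passed through the model's callbacks; before this change it would have seen the same batch-wide llm_output repeated for every prompt.
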
@@ -285,7 +313,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     *[run_manager.on_llm_error(e) for run_manager in run_managers]
                 )
                 raise e
-            flattened_outputs = output.flatten()
+            flattened_outputs = self._flatten_llm_result(prompts, output)
             await asyncio.gather(
                 *[
                     run_manager.on_llm_end(flattened_output)
@@ -318,7 +346,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                     *[run_manager.on_llm_error(e) for run_manager in run_managers]
                 )
                 raise e
-            flattened_outputs = new_results.flatten()
+            flattened_outputs = self._flatten_llm_result(missing_prompts, new_results)
             await asyncio.gather(
                 *[
                     run_manager.on_llm_end(flattened_output)
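
The async paths in agenerate make the same one-line substitution and then fan the per-prompt results out to their run managers concurrently with asyncio.gather. A stand-alone sketch of that fan-out pattern; the on_llm_end coroutine below is a stand-in, not the real async run manager method:

    # Stand-alone sketch of the concurrent fan-out used in the async branches.
    import asyncio


    async def on_llm_end(manager_id: int, flattened_output: dict) -> None:
        # stand-in for an async run manager's on_llm_end
        print(f"run manager {manager_id} received {flattened_output}")


    async def main() -> None:
        flattened_outputs = [
            {"prompt_tokens": 2, "completion_tokens": 1, "total_tokens": 3},
            {"prompt_tokens": 4, "completion_tokens": 4, "total_tokens": 8},
        ]
        await asyncio.gather(
            *[
                on_llm_end(i, flattened_output)
                for i, flattened_output in enumerate(flattened_outputs)
            ]
        )


    asyncio.run(main())
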
@@ -185,16 +185,6 @@ class LLMResult(BaseModel):
     run: Optional[List[RunInfo]] = None
     """Run metadata."""

-    def flatten(self) -> List[LLMResult]:
-        """Flatten generations into a single list."""
-        return [
-            LLMResult(
-                generations=[gen],
-                llm_output=self.llm_output,
-            )
-            for gen in self.generations
-        ]
-
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
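
With token usage now recomputed per prompt inside BaseLLM, the generic LLMResult.flatten() helper is dropped from the schema. Any external code that still needs the old split-without-recounting behavior can reproduce it in a few lines; a sketch, assuming langchain.schema.LLMResult as of this commit:

    # Equivalent of the removed LLMResult.flatten(), kept outside the schema:
    # one LLMResult per prompt, all sharing the batch-level llm_output.
    from typing import List

    from langchain.schema import LLMResult


    def flatten_llm_result(result: LLMResult) -> List[LLMResult]:
        return [
            LLMResult(generations=[gen], llm_output=result.llm_output)
            for gen in result.generations
        ]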