community[patch]: gather token usage info in BedrockChat during generation (#19127)

This PR makes it possible to calculate token usage for prompts and completions
directly in the generation method of BedrockChat. The token usage details are
then returned together with the generations, so that downstream tasks can
access them easily.
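For example, a downstream consumer could read the counts straight off the result's `llm_output`. A minimal sketch, assuming the fields added in this PR (`model_id` and `usage`); the model id and the exact usage key names shown are illustrative, not guaranteed by this change:

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

chat = BedrockChat(model_id="anthropic.claude-v2")  # example model id

# generate() returns an LLMResult; with this change its llm_output carries
# the model id plus the token usage gathered during generation.
result = chat.generate([[HumanMessage(content="Hello, Bedrock!")]])
print(result.llm_output["model_id"])
print(result.llm_output["usage"])  # e.g. {"prompt_tokens": ..., "completion_tokens": ...}
```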

This also makes it possible to define a callback for token tracking and cost
calculation, similar to what already exists for OpenAI (see
[OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)).
I plan on adding a BedrockCallbackHandler later.
Keeping track of tokens in a callback is already possible today, but it
requires passing the llm into the handler, as done here:
https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html.
However, I find the approach taken in this PR cleaner.
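A minimal sketch of the kind of handler this enables; the class name and the usage key names are hypothetical, only the `llm_output["usage"]` field comes from this PR:

```python
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class BedrockTokenUsageHandler(BaseCallbackHandler):
    """Accumulates the token counts exposed via llm_output["usage"]."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # llm_output is populated by BedrockChat._combine_llm_outputs after this PR.
        usage: Dict[str, int] = (response.llm_output or {}).get("usage", {})
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
```

Such a handler could then be passed via `callbacks=[...]` without holding a reference to the llm itself.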

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit: f7042321f1 (parent: a662468dde)
Author: Davide Menini
Date: 2024-03-28 19:58:46 +01:00
Committed by: GitHub
3 changed files with 65 additions and 15 deletions


@@ -1,4 +1,5 @@
 import re
+from collections import defaultdict
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        system = None
-        formatted_messages = None
+        prompt, system, formatted_messages = None, None, None
         if provider == "anthropic":
-            prompt = None
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
@@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
+        llm_output: Dict[str, Any] = {"model_id": self.model_id}
         if self.streaming:
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
             provider = self._get_provider()
-            system = None
-            formatted_messages = None
+            prompt, system, formatted_messages = None, None, None
             params: Dict[str, Any] = {**kwargs}
             if provider == "anthropic":
-                prompt = None
                 system, formatted_messages = ChatPromptAdapter.format_messages(
                     provider, messages
                 )
@@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
             if stop:
                 params["stop_sequences"] = stop
-            completion = self._prepare_input_and_invoke(
+            completion, usage_info = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
                 **params,
             )
+            llm_output["usage"] = usage_info
         return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=completion))]
+            generations=[ChatGeneration(message=AIMessage(content=completion))],
+            llm_output=llm_output,
         )
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        final_usage: Dict[str, int] = defaultdict(int)
+        final_output = {}
+        for output in llm_outputs:
+            output = output or {}
+            usage = output.pop("usage", {})
+            for token_type, token_count in usage.items():
+                final_usage[token_type] += token_count
+            final_output.update(output)
+        final_output["usage"] = final_usage
+        return final_output
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
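For reference, a standalone replica of the merge logic in `_combine_llm_outputs`, showing what `generate()` ends up reporting when several per-prompt outputs are combined (the model id and token counts below are invented):

```python
from collections import defaultdict
from typing import Dict, List, Optional


def combine_llm_outputs(llm_outputs: List[Optional[dict]]) -> dict:
    # Sum token counts per type across outputs; keep the last value of any
    # other key (e.g. model_id), mirroring _combine_llm_outputs above.
    final_usage: Dict[str, int] = defaultdict(int)
    final_output: dict = {}
    for output in llm_outputs:
        output = dict(output or {})  # copy so the caller's dicts stay intact
        usage = output.pop("usage", {})
        for token_type, token_count in usage.items():
            final_usage[token_type] += token_count
        final_output.update(output)
    final_output["usage"] = dict(final_usage)
    return final_output


print(
    combine_llm_outputs(
        [
            {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 10, "completion_tokens": 20}},
            {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 5, "completion_tokens": 8}},
        ]
    )
)
# {'model_id': 'anthropic.claude-v2', 'usage': {'prompt_tokens': 15, 'completion_tokens': 28}}
```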