Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-12 12:59:07 +00:00
community[patch]: gather token usage info in BedrockChat during generation (#19127)
This PR makes it possible to calculate token usage for prompts and completions directly in the generation method of BedrockChat. The token usage details are then returned together with the generations, so that other downstream tasks can access them easily.

This allows defining a callback for token tracking and cost calculation, similarly to what happens with OpenAI (see [OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)). I plan on adding a BedrockCallbackHandler later. Keeping track of tokens in a callback is already possible today, but it requires passing the llm, as done here: https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html. However, I find the approach of this PR cleaner.

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
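As a rough sketch of the kind of callback this change enables (not part of this PR), the following hypothetical handler accumulates whatever token counts BedrockChat reports under llm_output["usage"]. The class name and the exact key names inside the usage dict are illustrative assumptions; only the presence of a "usage" entry is shown by the diff below.

    from typing import Any, Dict

    from langchain_core.callbacks import BaseCallbackHandler
    from langchain_core.outputs import LLMResult


    class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
        """Hypothetical handler that sums token usage reported by BedrockChat."""

        def __init__(self) -> None:
            self.total_usage: Dict[str, int] = {}

        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            llm_output = response.llm_output or {}
            # "usage" is populated by BedrockChat._generate in this PR; the exact
            # key names (prompt vs. completion counts) depend on the provider.
            for token_type, count in llm_output.get("usage", {}).items():
                self.total_usage[token_type] = self.total_usage.get(token_type, 0) + count

With something like this in place, a handler instance could be passed through the usual callbacks mechanism, and after each call total_usage would hold the summed counts, which a cost-tracking handler could then price per model_id.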
@@ -1,4 +1,5 @@
 import re
+from collections import defaultdict
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        system = None
-        formatted_messages = None
+        prompt, system, formatted_messages = None, None, None
+
         if provider == "anthropic":
-            prompt = None
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
@@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
+        llm_output: Dict[str, Any] = {"model_id": self.model_id}
 
         if self.streaming:
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
             provider = self._get_provider()
-            system = None
-            formatted_messages = None
+            prompt, system, formatted_messages = None, None, None
             params: Dict[str, Any] = {**kwargs}
 
             if provider == "anthropic":
-                prompt = None
                 system, formatted_messages = ChatPromptAdapter.format_messages(
                     provider, messages
                 )
@@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
             if stop:
                 params["stop_sequences"] = stop
 
-            completion = self._prepare_input_and_invoke(
+            completion, usage_info = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
                 **params,
             )
 
+            llm_output["usage"] = usage_info
+
         return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=completion))]
+            generations=[ChatGeneration(message=AIMessage(content=completion))],
+            llm_output=llm_output,
         )
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        final_usage: Dict[str, int] = defaultdict(int)
+        final_output = {}
+        for output in llm_outputs:
+            output = output or {}
+            usage = output.pop("usage", {})
+            for token_type, token_count in usage.items():
+                final_usage[token_type] += token_count
+            final_output.update(output)
+        final_output["usage"] = final_usage
+        return final_output
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
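To make the aggregation step concrete, here is a small self-contained sketch that mirrors the _combine_llm_outputs method added above and shows what it produces for two per-generation outputs. The token key names in the example data are invented for illustration and may differ from what Bedrock actually reports.

    from collections import defaultdict
    from typing import Dict, List, Optional


    def combine_llm_outputs(llm_outputs: List[Optional[dict]]) -> dict:
        # Same accumulation as BedrockChat._combine_llm_outputs in the diff above:
        # sum token counts per type, keep the last value of every other field.
        final_usage: Dict[str, int] = defaultdict(int)
        final_output = {}
        for output in llm_outputs:
            output = output or {}
            usage = output.pop("usage", {})
            for token_type, token_count in usage.items():
                final_usage[token_type] += token_count
            final_output.update(output)
        final_output["usage"] = final_usage
        return final_output


    # Illustrative per-generation outputs (key names are assumptions, not Bedrock's):
    outputs = [
        {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 12, "completion_tokens": 30}},
        {"model_id": "anthropic.claude-v2", "usage": {"prompt_tokens": 8, "completion_tokens": 25}},
    ]
    result = combine_llm_outputs(outputs)
    # result["usage"] sums to {"prompt_tokens": 20, "completion_tokens": 55},
    # and result["model_id"] is carried over from the individual outputs.

This is the merge that generate() applies across multiple prompts, so the combined llm_output exposes one usage dict for the whole batch.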