From f7042321f1048801d3011ab5e7bdc3b570027de4 Mon Sep 17 00:00:00 2001
From: Davide Menini <48685774+dmenini@users.noreply.github.com>
Date: Thu, 28 Mar 2024 19:58:46 +0100
Subject: [PATCH] community[patch]: gather token usage info in BedrockChat during generation (#19127)

This PR computes token usage for prompts and completions directly in the
generation method of BedrockChat. The token usage details are then returned
together with the generations, so that downstream tasks can access them
easily.

This makes it possible to define a callback for token tracking and cost
calculation, similar to what is done for OpenAI (see
[OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)).
I plan on adding a BedrockCallbackHandler later. Right now, keeping track of
tokens in a callback is already possible, but it requires passing the LLM, as
done here:
https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html.
However, I find the approach in this PR cleaner.

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag
Co-authored-by: Bagatur
---
 .../chat_models/bedrock.py                    | 31 ++++++++++++++-----
 .../langchain_community/llms/bedrock.py       | 19 +++++++++---
 .../chat_models/test_bedrock.py               | 30 ++++++++++++++++--
 3 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index 933343d6a96..4cb73455f2d 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -1,4 +1,5 @@
 import re
+from collections import defaultdict
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        system = None
-        formatted_messages = None
+        prompt, system, formatted_messages = None, None, None
+
         if provider == "anthropic":
-            prompt = None
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
@@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
+        llm_output: Dict[str, Any] = {"model_id": self.model_id}
 
         if self.streaming:
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
             provider = self._get_provider()
-            system = None
-            formatted_messages = None
+            prompt, system, formatted_messages = None, None, None
             params: Dict[str, Any] = {**kwargs}
+
             if provider == "anthropic":
-                prompt = None
                 system, formatted_messages = ChatPromptAdapter.format_messages(
                     provider, messages
                 )
@@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
             if stop:
                 params["stop_sequences"] = stop
 
-            completion = self._prepare_input_and_invoke(
+            completion, usage_info = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
                 **params,
             )
 
+            llm_output["usage"] = usage_info
+
         return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=completion))]
+            generations=[ChatGeneration(message=AIMessage(content=completion))],
+            llm_output=llm_output,
         )
 
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        final_usage: Dict[str, int] = defaultdict(int)
+        final_output = {}
+        for output in llm_outputs:
+            output = output or {}
+            usage = output.pop("usage", {})
+            for token_type, token_count in usage.items():
+                final_usage[token_type] += token_count
+            final_output.update(output)
+        final_output["usage"] = final_usage
+        return final_output
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index 9b7515a5f4d..8ab5ad276ea 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -11,6 +11,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Tuple,
 )
 
 from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
+        text = ""
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
             if "completion" in response_body:
@@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
         else:
             text = response_body.get("results")[0].get("outputText")
 
+        headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
+        prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
+        completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
         return {
             "text": text,
             "body": response_body,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            },
         }
 
     @classmethod
@@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Tuple[str, Dict[str, Any]]:
         _model_kwargs = self.model_kwargs or {}
 
         provider = self._get_provider()
@@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):
         try:
             response = self.client.invoke_model(**request_options)
 
-            text, body = LLMInputOutputAdapter.prepare_output(
+            text, body, usage_info = LLMInputOutputAdapter.prepare_output(
                 provider, response
             ).values()
 
@@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
             **services_trace,
         )
 
-        return text
+        return text, usage_info
 
     def _get_bedrock_services_signal(self, body: dict) -> dict:
         """
@@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion
 
-        return self._prepare_input_and_invoke(
+        text, _ = self._prepare_input_and_invoke(
             prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
         )
+        return text
 
     async def _astream(
         self,
diff --git a/libs/community/tests/integration_tests/chat_models/test_bedrock.py b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
index 301260803d9..aa1dbaf8be7 100644
--- a/libs/community/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,9 +1,14 @@
 """Test Bedrock chat model."""
-from typing import Any
+from typing import Any, cast
 
 import pytest
 from langchain_core.callbacks import CallbackManager
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import ChatGeneration, LLMResult
 
 from langchain_community.chat_models import BedrockChat
@@ -39,6 +44,20 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
         assert generation.text == generation.message.content
 
 
+@pytest.mark.scheduled
+def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
+    """Test BedrockChat wrapper with generate."""
+    message = HumanMessage(content="Hello")
+    response = chat.generate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert isinstance(response.llm_output, dict)
+
+    usage = response.llm_output["usage"]
+    assert usage["prompt_tokens"] == 20
+    assert usage["completion_tokens"] > 0
+    assert usage["total_tokens"] > 0
+
+
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
@@ -80,15 +99,18 @@ def test_chat_bedrock_streaming_generation_info() -> None:
     list(chat.stream("hi"))
     generation = callback.saved_things["generation"]
     # `Hello!` is two tokens, assert that that is what is returned
-    assert generation.generations[0][0].text == " Hello!"
+    assert generation.generations[0][0].text == "Hello!"
 
 
 @pytest.mark.scheduled
 def test_bedrock_streaming(chat: BedrockChat) -> None:
     """Test streaming tokens from OpenAI."""
 
+    full = None
     for token in chat.stream("I'm Pickle Rick"):
+        full = token if full is None else full + token
         assert isinstance(token.content, str)
+    assert isinstance(cast(AIMessageChunk, full).content, str)
 
 
 @pytest.mark.scheduled
@@ -137,3 +159,5 @@ def test_bedrock_invoke(chat: BedrockChat) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+    assert all([k in result.response_metadata for k in ("usage", "model_id")])
+    assert result.response_metadata["usage"]["prompt_tokens"] == 13
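As a rough sketch of the downstream use this enables (not part of the patch): a callback handler can read the usage dict that BedrockChat now attaches to `llm_output` in its `on_llm_end` hook. The handler name `TokenUsageHandler` and the `model_id` below are illustrative assumptions, not the planned BedrockCallbackHandler.

```python
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from langchain_community.chat_models import BedrockChat


class TokenUsageHandler(BaseCallbackHandler):
    """Accumulate the token counts BedrockChat reports in llm_output["usage"]."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # Non-streaming BedrockChat generations now carry
        # {"model_id": ..., "usage": {...}} in llm_output.
        usage: Dict[str, int] = (response.llm_output or {}).get("usage", {})
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)


# Illustrative usage; the model_id is only an example.
handler = TokenUsageHandler()
chat = BedrockChat(model_id="anthropic.claude-v2", callbacks=[handler])
chat.invoke("Hello")
print(handler.prompt_tokens, handler.completion_tokens, handler.total_tokens)
```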