community[patch]: gather token usage info in BedrockChat during generation (#19127)

This PR gathers token usage for prompts and completions directly in the
generation method of BedrockChat. The token usage details are then
returned together with the generations, so that downstream tasks can
access them easily.
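For example (a minimal sketch; the `model_id` below and the configured AWS credentials are assumptions, not part of this PR):

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

# Assumes AWS credentials are configured; the model_id is only an example.
chat = BedrockChat(model_id="anthropic.claude-v2")

result = chat.generate([[HumanMessage(content="Hello")]])
# llm_output now carries the model id plus the aggregated token usage.
print(result.llm_output["model_id"])
print(result.llm_output["usage"])
# e.g. {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}
```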

This makes it possible to define a callback for token tracking and cost
calculation, similar to what happens with OpenAI (see
[OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)).
I plan on adding a BedrockCallbackHandler later.
Keeping track of tokens in a callback is already possible today, but it
requires passing the llm explicitly, as done here:
https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html.
However, I find the approach taken in this PR cleaner.
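Such a callback could read the usage directly from `llm_output` in `on_llm_end`. A rough sketch (the class name `BedrockTokenUsageCallbackHandler` is hypothetical and not part of this PR):

```python
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
    """Hypothetical handler that accumulates Bedrock token usage."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # llm_output["usage"] is populated by BedrockChat as of this PR.
        usage: Dict[str, int] = (response.llm_output or {}).get("usage", {})
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)
```

Passing an instance via `callbacks=[handler]` would then accumulate usage across calls, much like the OpenAI handler does.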

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Committed by Davide Menini on 2024-03-28 19:58:46 +01:00 via GitHub (commit f7042321f1, parent a662468dde).
3 changed files with 65 additions and 15 deletions

View File

@@ -1,4 +1,5 @@
import re
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
system = None
formatted_messages = None
prompt, system, formatted_messages = None, None, None
if provider == "anthropic":
prompt = None
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
@@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
**kwargs: Any,
) -> ChatResult:
completion = ""
llm_output: Dict[str, Any] = {"model_id": self.model_id}
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
system = None
formatted_messages = None
prompt, system, formatted_messages = None, None, None
params: Dict[str, Any] = {**kwargs}
if provider == "anthropic":
prompt = None
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
@@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
if stop:
params["stop_sequences"] = stop
completion = self._prepare_input_and_invoke(
completion, usage_info = self._prepare_input_and_invoke(
prompt=prompt,
stop=stop,
run_manager=run_manager,
@@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
**params,
)
llm_output["usage"] = usage_info
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=completion))]
generations=[ChatGeneration(message=AIMessage(content=completion))],
llm_output=llm_output,
)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
final_usage: Dict[str, int] = defaultdict(int)
final_output = {}
for output in llm_outputs:
output = output or {}
usage = output.pop("usage", {})
for token_type, token_count in usage.items():
final_usage[token_type] += token_count
final_output.update(output)
final_output["usage"] = final_usage
return final_output
def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:
return get_num_tokens_anthropic(text)

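For reference, the aggregation that `_combine_llm_outputs` performs when `generate` is called with several message lists looks roughly like this (standalone sketch with made-up numbers):

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional


def combine(llm_outputs: List[Optional[dict]]) -> dict:
    # Mirrors BedrockChat._combine_llm_outputs: sum token counts per type,
    # keep the remaining keys (e.g. model_id) from the individual outputs.
    final_usage: Dict[str, int] = defaultdict(int)
    final_output: Dict[str, Any] = {}
    for output in llm_outputs:
        output = dict(output or {})
        for token_type, token_count in output.pop("usage", {}).items():
            final_usage[token_type] += token_count
        final_output.update(output)
    final_output["usage"] = final_usage
    return final_output


combined = combine([
    {"model_id": "anthropic.claude-v2",
     "usage": {"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}},
    {"model_id": "anthropic.claude-v2",
     "usage": {"prompt_tokens": 20, "completion_tokens": 7, "total_tokens": 27}},
])
# combined["usage"] -> {"prompt_tokens": 40, "completion_tokens": 12, "total_tokens": 52}
```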
View File

@@ -11,6 +11,7 @@ from typing import (
List,
Mapping,
Optional,
Tuple,
)
from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ class LLMInputOutputAdapter:
@classmethod
def prepare_output(cls, provider: str, response: Any) -> dict:
text = ""
if provider == "anthropic":
response_body = json.loads(response.get("body").read().decode())
if "completion" in response_body:
@@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
else:
text = response_body.get("results")[0].get("outputText")
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
return {
"text": text,
"body": response_body,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
@classmethod
@@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
) -> Tuple[str, Dict[str, Any]]:
_model_kwargs = self.model_kwargs or {}
provider = self._get_provider()
@@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):
try:
response = self.client.invoke_model(**request_options)
text, body = LLMInputOutputAdapter.prepare_output(
text, body, usage_info = LLMInputOutputAdapter.prepare_output(
provider, response
).values()
@@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
**services_trace,
)
return text
return text, usage_info
def _get_bedrock_services_signal(self, body: dict) -> dict:
"""
@@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
completion += chunk.text
return completion
return self._prepare_input_and_invoke(
text, _ = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
return text
async def _astream(
self,

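The usage numbers themselves come from the HTTP headers that Bedrock attaches to each `invoke_model` response, as the `prepare_output` change above shows. A standalone illustration of that header parsing, using a mocked boto3-style response (the mock shape and numbers are illustrative only):

```python
from typing import Any, Dict


def extract_usage(response: Dict[str, Any]) -> Dict[str, int]:
    # Same headers LLMInputOutputAdapter.prepare_output reads above.
    headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
    prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
    completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


# Mocked metadata in the shape boto3 attaches to invoke_model responses.
mock_response = {
    "ResponseMetadata": {
        "HTTPHeaders": {
            "x-amzn-bedrock-input-token-count": "13",
            "x-amzn-bedrock-output-token-count": "3",
        }
    }
}
print(extract_usage(mock_response))
# {'prompt_tokens': 13, 'completion_tokens': 3, 'total_tokens': 16}
```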
View File

@@ -1,9 +1,14 @@
"""Test Bedrock chat model."""
from typing import Any
from typing import Any, cast
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models import BedrockChat
@@ -39,6 +44,20 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
assert generation.text == generation.message.content
@pytest.mark.scheduled
def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
"""Test BedrockChat wrapper with generate."""
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert isinstance(response.llm_output, dict)
usage = response.llm_output["usage"]
assert usage["prompt_tokens"] == 20
assert usage["completion_tokens"] > 0
assert usage["total_tokens"] > 0
@pytest.mark.scheduled
def test_chat_bedrock_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
@@ -80,15 +99,18 @@ def test_chat_bedrock_streaming_generation_info() -> None:
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
assert generation.generations[0][0].text == " Hello!"
assert generation.generations[0][0].text == "Hello!"
@pytest.mark.scheduled
def test_bedrock_streaming(chat: BedrockChat) -> None:
"""Test streaming tokens from OpenAI."""
full = None
for token in chat.stream("I'm Pickle Rick"):
full = token if full is None else full + token
assert isinstance(token.content, str)
assert isinstance(cast(AIMessageChunk, full).content, str)
@pytest.mark.scheduled
@@ -137,3 +159,5 @@ def test_bedrock_invoke(chat: BedrockChat) -> None:
"""Test invoke tokens from BedrockChat."""
result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
assert all([k in result.response_metadata for k in ("usage", "model_id")])
assert result.response_metadata["usage"]["prompt_tokens"] == 13