mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 09:48:04 +00:00
community[patch]: gather token usage info in BedrockChat during generation (#19127)
This PR allows to calculate token usage for prompts and completion directly in the generation method of BedrockChat. The token usage details are then returned together with the generations, so that other downstream tasks can access them easily. This allows to define a callback for tokens tracking and cost calculation, similarly to what happens with OpenAI (see [OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler). I plan on adding a BedrockCallbackHandler later. Right now keeping track of tokens in the callback is already possible, but it requires passing the llm, as done here: https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html. However, I find the approach of this PR cleaner. Thanks for your reviews. FYI @baskaryan, @hwchase17 --------- Co-authored-by: taamedag <Davide.Menini@swisscom.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
a662468dde
commit
f7042321f1
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
provider = self._get_provider()
|
||||
system = None
|
||||
formatted_messages = None
|
||||
prompt, system, formatted_messages = None, None, None
|
||||
|
||||
if provider == "anthropic":
|
||||
prompt = None
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
completion = ""
|
||||
llm_output: Dict[str, Any] = {"model_id": self.model_id}
|
||||
|
||||
if self.streaming:
|
||||
for chunk in self._stream(messages, stop, run_manager, **kwargs):
|
||||
completion += chunk.text
|
||||
else:
|
||||
provider = self._get_provider()
|
||||
system = None
|
||||
formatted_messages = None
|
||||
prompt, system, formatted_messages = None, None, None
|
||||
params: Dict[str, Any] = {**kwargs}
|
||||
|
||||
if provider == "anthropic":
|
||||
prompt = None
|
||||
system, formatted_messages = ChatPromptAdapter.format_messages(
|
||||
provider, messages
|
||||
)
|
||||
@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
if stop:
|
||||
params["stop_sequences"] = stop
|
||||
|
||||
completion = self._prepare_input_and_invoke(
|
||||
completion, usage_info = self._prepare_input_and_invoke(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
||||
**params,
|
||||
)
|
||||
|
||||
llm_output["usage"] = usage_info
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=completion))]
|
||||
generations=[ChatGeneration(message=AIMessage(content=completion))],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
final_usage: Dict[str, int] = defaultdict(int)
|
||||
final_output = {}
|
||||
for output in llm_outputs:
|
||||
output = output or {}
|
||||
usage = output.pop("usage", {})
|
||||
for token_type, token_count in usage.items():
|
||||
final_usage[token_type] += token_count
|
||||
final_output.update(output)
|
||||
final_output["usage"] = final_usage
|
||||
return final_output
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
if self._model_is_anthropic:
|
||||
return get_num_tokens_anthropic(text)
|
||||
|
@ -11,6 +11,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
@ -141,6 +142,7 @@ class LLMInputOutputAdapter:
|
||||
|
||||
@classmethod
|
||||
def prepare_output(cls, provider: str, response: Any) -> dict:
|
||||
text = ""
|
||||
if provider == "anthropic":
|
||||
response_body = json.loads(response.get("body").read().decode())
|
||||
if "completion" in response_body:
|
||||
@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
|
||||
else:
|
||||
text = response_body.get("results")[0].get("outputText")
|
||||
|
||||
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
|
||||
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
|
||||
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
|
||||
return {
|
||||
"text": text,
|
||||
"body": response_body,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
provider = self._get_provider()
|
||||
@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):
|
||||
try:
|
||||
response = self.client.invoke_model(**request_options)
|
||||
|
||||
text, body = LLMInputOutputAdapter.prepare_output(
|
||||
text, body, usage_info = LLMInputOutputAdapter.prepare_output(
|
||||
provider, response
|
||||
).values()
|
||||
|
||||
@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
|
||||
**services_trace,
|
||||
)
|
||||
|
||||
return text
|
||||
return text, usage_info
|
||||
|
||||
def _get_bedrock_services_signal(self, body: dict) -> dict:
|
||||
"""
|
||||
@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
|
||||
return self._prepare_input_and_invoke(
|
||||
text, _ = self._prepare_input_and_invoke(
|
||||
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return text
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
|
@ -1,9 +1,14 @@
|
||||
"""Test Bedrock chat model."""
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
from langchain_community.chat_models import BedrockChat
|
||||
@ -39,6 +44,20 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
|
||||
assert generation.text == generation.message.content
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
|
||||
"""Test BedrockChat wrapper with generate."""
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.generate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert isinstance(response.llm_output, dict)
|
||||
|
||||
usage = response.llm_output["usage"]
|
||||
assert usage["prompt_tokens"] == 20
|
||||
assert usage["completion_tokens"] > 0
|
||||
assert usage["total_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_chat_bedrock_streaming() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
@ -80,15 +99,18 @@ def test_chat_bedrock_streaming_generation_info() -> None:
|
||||
list(chat.stream("hi"))
|
||||
generation = callback.saved_things["generation"]
|
||||
# `Hello!` is two tokens, assert that that is what is returned
|
||||
assert generation.generations[0][0].text == " Hello!"
|
||||
assert generation.generations[0][0].text == "Hello!"
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_bedrock_streaming(chat: BedrockChat) -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
|
||||
full = None
|
||||
for token in chat.stream("I'm Pickle Rick"):
|
||||
full = token if full is None else full + token
|
||||
assert isinstance(token.content, str)
|
||||
assert isinstance(cast(AIMessageChunk, full).content, str)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@ -137,3 +159,5 @@ def test_bedrock_invoke(chat: BedrockChat) -> None:
|
||||
"""Test invoke tokens from BedrockChat."""
|
||||
result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
assert all([k in result.response_metadata for k in ("usage", "model_id")])
|
||||
assert result.response_metadata["usage"]["prompt_tokens"] == 13
|
||||
|
Loading…
Reference in New Issue
Block a user