community[patch]: gather token usage info in BedrockChat during generation (#19127)

This PR allows to calculate token usage for prompts and completion
directly in the generation method of BedrockChat. The token usage
details are then returned together with the generations, so that other
downstream tasks can access them easily.

This allows to define a callback for tokens tracking and cost
calculation, similarly to what happens with OpenAI (see
[OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler).
I plan on adding a BedrockCallbackHandler later.
Right now keeping track of tokens in the callback is already possible,
but it requires passing the llm, as done here:
https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html.
However, I find the approach of this PR cleaner.

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Davide Menini
2024-03-28 19:58:46 +01:00
committed by GitHub
parent a662468dde
commit f7042321f1
3 changed files with 65 additions and 15 deletions

View File

@@ -11,6 +11,7 @@ from typing import (
List,
Mapping,
Optional,
Tuple,
)
from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ class LLMInputOutputAdapter:
@classmethod
def prepare_output(cls, provider: str, response: Any) -> dict:
text = ""
if provider == "anthropic":
response_body = json.loads(response.get("body").read().decode())
if "completion" in response_body:
@@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
else:
text = response_body.get("results")[0].get("outputText")
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
return {
"text": text,
"body": response_body,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
@classmethod
@@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
) -> Tuple[str, Dict[str, Any]]:
_model_kwargs = self.model_kwargs or {}
provider = self._get_provider()
@@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):
try:
response = self.client.invoke_model(**request_options)
text, body = LLMInputOutputAdapter.prepare_output(
text, body, usage_info = LLMInputOutputAdapter.prepare_output(
provider, response
).values()
@@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
**services_trace,
)
return text
return text, usage_info
def _get_bedrock_services_signal(self, body: dict) -> dict:
"""
@@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
completion += chunk.text
return completion
return self._prepare_input_and_invoke(
text, _ = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
)
return text
async def _astream(
self,