community[patch]: gather token usage info in BedrockChat during generation (#19127)
This PR calculates token usage for prompts and completions directly in the generation method of BedrockChat. The token usage details are then returned together with the generations, so that downstream tasks can access them easily. This makes it possible to define a callback for token tracking and cost calculation, similar to what exists for OpenAI (see [OpenAICallbackHandler](https://api.python.langchain.com/en/latest/_modules/langchain_community/callbacks/openai_info.html#OpenAICallbackHandler)). I plan on adding a BedrockCallbackHandler later.

Keeping track of tokens in a callback is already possible today, but it requires passing the llm into the callback, as done here: https://how.wtf/how-to-count-amazon-bedrock-anthropic-tokens-with-langchain.html. I find the approach of this PR cleaner.

Thanks for your reviews. FYI @baskaryan, @hwchase17

---------

Co-authored-by: taamedag <Davide.Menini@swisscom.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
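For reference, a minimal sketch of what such a handler could look like once the usage details land in `llm_output`. The class name `BedrockTokenUsageCallbackHandler` and the exact `"usage"` key layout are assumptions based on this PR, not an existing LangChain API:

```python
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class BedrockTokenUsageCallbackHandler(BaseCallbackHandler):
    """Hypothetical handler that accumulates Bedrock token counts."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # "usage" is where this PR surfaces the counts; adjust if it moves.
        usage = (response.llm_output or {}).get("usage", {})
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)
```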
@@ -11,6 +11,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Tuple,
 )
 
 from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ class LLMInputOutputAdapter:
 
     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
+        text = ""
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
             if "completion" in response_body:
@@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
         else:
             text = response_body.get("results")[0].get("outputText")
 
+        headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
+        prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
+        completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
         return {
             "text": text,
             "body": response_body,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            },
         }
 
     @classmethod
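The counts come from Bedrock's HTTP response metadata rather than from the JSON body, which is why the hunk above reads `ResponseMetadata` instead of `response_body`. A self-contained illustration of that lookup on a boto3-shaped response dict (the header values are made up):

```python
# Shape of the metadata boto3 returns from bedrock-runtime invoke_model
# (token counts arrive as HTTP headers, encoded as strings).
response = {
    "ResponseMetadata": {
        "HTTPHeaders": {
            "x-amzn-bedrock-input-token-count": "12",
            "x-amzn-bedrock-output-token-count": "34",
        }
    }
}

headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
assert (prompt_tokens, completion_tokens) == (12, 34)
```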
@@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Tuple[str, Dict[str, Any]]:
         _model_kwargs = self.model_kwargs or {}
 
         provider = self._get_provider()
@@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):
         try:
             response = self.client.invoke_model(**request_options)
 
-            text, body = LLMInputOutputAdapter.prepare_output(
+            text, body, usage_info = LLMInputOutputAdapter.prepare_output(
                 provider, response
             ).values()
 
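One subtlety in this call site: unpacking `.values()` into three names only works because dicts preserve insertion order (guaranteed since Python 3.7), so the `"text"`, `"body"`, `"usage"` keys must keep that order in `prepare_output`. A toy illustration with a hypothetical stub:

```python
def prepare_output_stub() -> dict:
    # Key order must match the unpacking order at the call site.
    return {"text": "hi", "body": {}, "usage": {"total_tokens": 3}}

text, body, usage_info = prepare_output_stub().values()
assert text == "hi" and usage_info["total_tokens"] == 3
```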
@@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
             **services_trace,
         )
 
-        return text
+        return text, usage_info
 
     def _get_bedrock_services_signal(self, body: dict) -> dict:
         """
@@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion
 
-        return self._prepare_input_and_invoke(
+        text, _ = self._prepare_input_and_invoke(
             prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
         )
+        return text
 
     async def _astream(
         self,
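End to end, a caller could then read the counts off the generation result. A hedged sketch, assuming configured AWS credentials and that BedrockChat surfaces the usage dict in `llm_output` as described above (the model id is just an example):

```python
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage

chat = BedrockChat(model_id="anthropic.claude-v2")  # example model id
result = chat.generate([[HumanMessage(content="Hello!")]])
# Expected to include the token counts gathered during generation.
print(result.llm_output)
```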