Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 07:09:31 +00:00)
community[minor]: add bedrock anthropic callback for token usage counting (#19864)

**Description:** Add a Bedrock Anthropic callback handler for token usage counting, modeled on the OpenAI callback.

Co-authored-by: Massimiliano Pronesti <massimiliano.pronesti@gmail.com>
parent 1f9f4d8742
commit ad9750403b
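For context, the new handler is meant to be used through the get_bedrock_anthropic_callback context manager added in this diff, mirroring get_openai_callback. A minimal usage sketch (not part of the diff; assumes AWS credentials are configured and uses the BedrockChat model class that this PR also touches):

    from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
    from langchain_community.chat_models import BedrockChat

    # Assumes AWS credentials/region are configured in the environment.
    llm = BedrockChat(model_id="anthropic.claude-instant-v1")

    with get_bedrock_anthropic_callback() as cb:
        llm.invoke("Tell me a joke")
        print(cb)  # Tokens Used / Prompt Tokens / Completion Tokens / Total Cost (USD)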
libs/community/langchain_community/callbacks/bedrock_anthropic_callback.py (new file):

@@ -0,0 +1,111 @@
import threading
from typing import Any, Dict, List, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

MODEL_COST_PER_1K_INPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.0008,
    "anthropic.claude-v2": 0.008,
    "anthropic.claude-v2:1": 0.008,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.003,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00025,
}

MODEL_COST_PER_1K_OUTPUT_TOKENS = {
    "anthropic.claude-instant-v1": 0.0024,
    "anthropic.claude-v2": 0.024,
    "anthropic.claude-v2:1": 0.024,
    "anthropic.claude-3-sonnet-20240229-v1:0": 0.015,
    "anthropic.claude-3-haiku-20240307-v1:0": 0.00125,
}


def _get_anthropic_claude_token_cost(
    prompt_tokens: int, completion_tokens: int, model_id: Union[str, None]
) -> float:
    """Get the cost of tokens for the Claude model."""
    if model_id not in MODEL_COST_PER_1K_INPUT_TOKENS:
        raise ValueError(
            f"Unknown model: {model_id}. Please provide a valid Anthropic model name. "
            "Known models are: " + ", ".join(MODEL_COST_PER_1K_INPUT_TOKENS.keys())
        )
    return (prompt_tokens / 1000) * MODEL_COST_PER_1K_INPUT_TOKENS[model_id] + (
        completion_tokens / 1000
    ) * MODEL_COST_PER_1K_OUTPUT_TOKENS[model_id]
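A quick sanity check of the pricing arithmetic, on hypothetical numbers rather than anything from the PR: 1,000 prompt tokens and 500 completion tokens on anthropic.claude-instant-v1 should cost (1000/1000) * 0.0008 + (500/1000) * 0.0024 = 0.002 USD:

    # Hypothetical worked example of _get_anthropic_claude_token_cost.
    cost = _get_anthropic_claude_token_cost(
        prompt_tokens=1000,
        completion_tokens=500,
        model_id="anthropic.claude-instant-v1",
    )
    assert round(cost, 6) == 0.002  # 0.0008 input + 0.0012 output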

class BedrockAnthropicTokenUsageCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks bedrock anthropic info."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing when the LLM starts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing on a new token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        if response.llm_output is None:
            return None

        if "usage" not in response.llm_output:
            with self._lock:
                self.successful_requests += 1
            return None

        # compute tokens and cost for this request
        token_usage = response.llm_output["usage"]
        completion_tokens = token_usage.get("completion_tokens", 0)
        prompt_tokens = token_usage.get("prompt_tokens", 0)
        total_tokens = token_usage.get("total_tokens", 0)
        model_id = response.llm_output.get("model_id", None)
        total_cost = _get_anthropic_claude_token_cost(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            model_id=model_id,
        )

        # update shared state behind lock
        with self._lock:
            self.total_cost += total_cost
            self.total_tokens += total_tokens
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.successful_requests += 1

    def __copy__(self) -> "BedrockAnthropicTokenUsageCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "BedrockAnthropicTokenUsageCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self
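One design note worth calling out: __copy__ and __deepcopy__ return self, so any copies LangChain makes of the handler internally still share one set of counters, and token counts keep aggregating in a single place. A small illustration (not from the PR):

    import copy

    cb = BedrockAnthropicTokenUsageCallbackHandler()
    # Copies alias the original instance, so counts accumulate in one handler.
    assert copy.copy(cb) is cb
    assert copy.deepcopy(cb) is cb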

libs/community/langchain_community/callbacks/manager.py:

@@ -10,6 +10,9 @@ from typing import (

from langchain_core.tracers.context import register_configure_hook

from langchain_community.callbacks.bedrock_anthropic_callback import (
    BedrockAnthropicTokenUsageCallbackHandler,
)
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_community.callbacks.tracers.comet import CometTracer
from langchain_community.callbacks.tracers.wandb import WandbTracer

@@ -19,7 +22,10 @@ logger = logging.getLogger(__name__)
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
    "openai_callback", default=None
)
bedrock_anthropic_callback_var: ContextVar[
    Optional[BedrockAnthropicTokenUsageCallbackHandler]
] = ContextVar("bedrock_anthropic_callback", default=None)
wandb_tracing_callback_var: ContextVar[Optional[WandbTracer]] = ContextVar(
    "tracing_wandb_callback", default=None
)
comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(  # noqa: E501

@@ -27,6 +33,7 @@ comet_tracing_callback_var: ContextVar[Optional[CometTracer]] = ContextVar(  # n
)

register_configure_hook(openai_callback_var, True)
register_configure_hook(bedrock_anthropic_callback_var, True)
register_configure_hook(
    wandb_tracing_callback_var, True, WandbTracer, "LANGCHAIN_WANDB_TRACING"
)
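register_configure_hook wires the ContextVar into CallbackManager.configure: whenever the var holds a handler, configure() attaches it to the resulting manager (the True flag marks the handler as inheritable). A rough sketch of the effect, assuming the imports above (not in the PR):

    from langchain_core.callbacks import CallbackManager

    cb = BedrockAnthropicTokenUsageCallbackHandler()
    bedrock_anthropic_callback_var.set(cb)
    mngr = CallbackManager.configure()
    assert cb in mngr.handlers  # picked up via the registered hook
    bedrock_anthropic_callback_var.set(None)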

@@ -53,6 +60,27 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
    openai_callback_var.set(None)


@contextmanager
def get_bedrock_anthropic_callback() -> (
    Generator[BedrockAnthropicTokenUsageCallbackHandler, None, None]
):
    """Get the Bedrock anthropic callback handler in a context manager,
    which conveniently exposes token and cost information.

    Returns:
        BedrockAnthropicTokenUsageCallbackHandler:
            The Bedrock anthropic callback handler.

    Example:
        >>> with get_bedrock_anthropic_callback() as cb:
        ...     # Use the Bedrock anthropic callback handler
    """
    cb = BedrockAnthropicTokenUsageCallbackHandler()
    bedrock_anthropic_callback_var.set(cb)
    yield cb
    bedrock_anthropic_callback_var.set(None)


@contextmanager
def wandb_tracing_enabled(
    session_name: str = "default",

libs/community/langchain_community/chat_models/bedrock.py:

@@ -308,7 +308,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
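The pop-to-get switch matters because .pop mutates the per-generation llm_output dicts in place, whereas .get reads without mutating; presumably this was changed here so the usage data stays available to downstream consumers such as the new callback handler. A toy demonstration on hypothetical data:

    # Hypothetical data showing the .pop side effect the change avoids.
    output = {"model_id": "anthropic.claude-instant-v1", "usage": {"total_tokens": 3}}
    usage = output.pop("usage", {})  # old behavior: removes the key
    assert "usage" not in output     # usage is gone for later readers

    output = {"model_id": "anthropic.claude-instant-v1", "usage": {"total_tokens": 3}}
    usage = output.get("usage", {})  # new behavior: read without mutating
    assert "usage" in output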

In the callback-manager unit tests:

@@ -7,6 +7,7 @@ from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer, wait_for_all_tracers

from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
from langchain_community.llms.openai import BaseOpenAI

@@ -77,6 +78,37 @@ def test_callback_manager_configure_context_vars(
        )
        mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)

        # The callback handler has been updated
        assert cb.successful_requests == 1
        assert cb.total_tokens == 3
        assert cb.prompt_tokens == 2
        assert cb.completion_tokens == 1
        assert cb.total_cost > 0

    with get_bedrock_anthropic_callback() as cb:
        # This is a new empty callback handler
        assert cb.successful_requests == 0
        assert cb.total_tokens == 0

        # configure adds this bedrock anthropic cb,
        # but doesn't modify the group manager
        mngr = CallbackManager.configure(group_manager)
        assert mngr.handlers == [tracer, cb]
        assert group_manager.handlers == [tracer]

        response = LLMResult(
            generations=[],
            llm_output={
                "usage": {
                    "prompt_tokens": 2,
                    "completion_tokens": 1,
                    "total_tokens": 3,
                },
                "model_id": "anthropic.claude-instant-v1",
            },
        )
        mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)

        # The callback handler has been updated
        assert cb.successful_requests == 1
        assert cb.total_tokens == 3

In the BedrockChat unit tests:

@@ -58,3 +58,32 @@ def test_different_models_bedrock(model_id: str) -> None:

    # should not throw an error
    model.invoke("hello there")


def test_bedrock_combine_llm_output() -> None:
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
    client = MagicMock()
    llm_outputs = [
        {
            "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
            "usage": {
                "completion_tokens": 1,
                "prompt_tokens": 2,
                "total_tokens": 3,
            },
        },
        {
            "model_id": "anthropic.claude-3-haiku-20240307-v1:0",
            "usage": {
                "completion_tokens": 1,
                "prompt_tokens": 2,
                "total_tokens": 3,
            },
        },
    ]
    model = BedrockChat(model_id=model_id, client=client)
    final_output = model._combine_llm_outputs(llm_outputs)
    assert final_output["model_id"] == model_id
    assert final_output["usage"]["completion_tokens"] == 2
    assert final_output["usage"]["prompt_tokens"] == 4
    assert final_output["usage"]["total_tokens"] == 6