diff --git a/libs/community/langchain_community/chat_models/bedrock.py b/libs/community/langchain_community/chat_models/bedrock.py
index 933343d6a96..4cb73455f2d 100644
--- a/libs/community/langchain_community/chat_models/bedrock.py
+++ b/libs/community/langchain_community/chat_models/bedrock.py
@@ -1,4 +1,5 @@
 import re
+from collections import defaultdict
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 from langchain_core.callbacks import (
@@ -234,10 +235,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         provider = self._get_provider()
-        system = None
-        formatted_messages = None
+        prompt, system, formatted_messages = None, None, None
+
         if provider == "anthropic":
-            prompt = None
             system, formatted_messages = ChatPromptAdapter.format_messages(
                 provider, messages
             )
@@ -265,17 +265,17 @@ class BedrockChat(BaseChatModel, BedrockBase):
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
+        llm_output: Dict[str, Any] = {"model_id": self.model_id}

         if self.streaming:
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
         else:
             provider = self._get_provider()
-            system = None
-            formatted_messages = None
+            prompt, system, formatted_messages = None, None, None
             params: Dict[str, Any] = {**kwargs}
+
             if provider == "anthropic":
-                prompt = None
                 system, formatted_messages = ChatPromptAdapter.format_messages(
                     provider, messages
                 )
@@ -287,7 +287,7 @@ class BedrockChat(BaseChatModel, BedrockBase):
             if stop:
                 params["stop_sequences"] = stop

-            completion = self._prepare_input_and_invoke(
+            completion, usage_info = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -296,10 +296,25 @@ class BedrockChat(BaseChatModel, BedrockBase):
                 **params,
             )

+        llm_output["usage"] = usage_info
+
         return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=completion))]
+            generations=[ChatGeneration(message=AIMessage(content=completion))],
+            llm_output=llm_output,
         )

+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        final_usage: Dict[str, int] = defaultdict(int)
+        final_output = {}
+        for output in llm_outputs:
+            output = output or {}
+            usage = output.pop("usage", {})
+            for token_type, token_count in usage.items():
+                final_usage[token_type] += token_count
+            final_output.update(output)
+        final_output["usage"] = final_usage
+        return final_output
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index 9b7515a5f4d..8ab5ad276ea 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -11,6 +11,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Tuple,
 )

 from langchain_core.callbacks import (
@@ -141,6 +142,7 @@ class LLMInputOutputAdapter:

     @classmethod
     def prepare_output(cls, provider: str, response: Any) -> dict:
+        text = ""
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
             if "completion" in response_body:
@@ -162,9 +164,17 @@ class LLMInputOutputAdapter:
         else:
             text = response_body.get("results")[0].get("outputText")

+        headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
+        prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
+        completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
         return {
             "text": text,
             "body": response_body,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            },
         }

     @classmethod
@@ -498,7 +508,7 @@ class BedrockBase(BaseModel, ABC):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Tuple[str, Dict[str, Any]]:
         _model_kwargs = self.model_kwargs or {}

         provider = self._get_provider()
@@ -531,7 +541,7 @@ class BedrockBase(BaseModel, ABC):

         try:
             response = self.client.invoke_model(**request_options)

-            text, body = LLMInputOutputAdapter.prepare_output(
+            text, body, usage_info = LLMInputOutputAdapter.prepare_output(
                 provider, response
             ).values()
@@ -554,7 +564,7 @@ class BedrockBase(BaseModel, ABC):
                 **services_trace,
             )

-        return text
+        return text, usage_info

     def _get_bedrock_services_signal(self, body: dict) -> dict:
         """
@@ -824,9 +834,10 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion

-        return self._prepare_input_and_invoke(
+        text, _ = self._prepare_input_and_invoke(
             prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
         )
+        return text

     async def _astream(
         self,
diff --git a/libs/community/tests/integration_tests/chat_models/test_bedrock.py b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
index 301260803d9..aa1dbaf8be7 100644
--- a/libs/community/tests/integration_tests/chat_models/test_bedrock.py
+++ b/libs/community/tests/integration_tests/chat_models/test_bedrock.py
@@ -1,9 +1,14 @@
 """Test Bedrock chat model."""
-from typing import Any
+from typing import Any, cast

 import pytest
 from langchain_core.callbacks import CallbackManager
-from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
+from langchain_core.messages import (
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.outputs import ChatGeneration, LLMResult

 from langchain_community.chat_models import BedrockChat
@@ -39,6 +44,20 @@ def test_chat_bedrock_generate(chat: BedrockChat) -> None:
             assert generation.text == generation.message.content


+@pytest.mark.scheduled
+def test_chat_bedrock_generate_with_token_usage(chat: BedrockChat) -> None:
+    """Test BedrockChat wrapper with generate."""
+    message = HumanMessage(content="Hello")
+    response = chat.generate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert isinstance(response.llm_output, dict)
+
+    usage = response.llm_output["usage"]
+    assert usage["prompt_tokens"] == 20
+    assert usage["completion_tokens"] > 0
+    assert usage["total_tokens"] > 0
+
+
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
@@ -80,15 +99,18 @@ def test_chat_bedrock_streaming_generation_info() -> None:
     list(chat.stream("hi"))
     generation = callback.saved_things["generation"]
     # `Hello!` is two tokens, assert that that is what is returned
-    assert generation.generations[0][0].text == " Hello!"
+    assert generation.generations[0][0].text == "Hello!"


 @pytest.mark.scheduled
 def test_bedrock_streaming(chat: BedrockChat) -> None:
     """Test streaming tokens from OpenAI."""

+    full = None
     for token in chat.stream("I'm Pickle Rick"):
+        full = token if full is None else full + token
         assert isinstance(token.content, str)
+    assert isinstance(cast(AIMessageChunk, full).content, str)


 @pytest.mark.scheduled
@@ -137,3 +159,5 @@ def test_bedrock_invoke(chat: BedrockChat) -> None:
     """Test invoke tokens from BedrockChat."""
     result = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+    assert all([k in result.response_metadata for k in ("usage", "model_id")])
+    assert result.response_metadata["usage"]["prompt_tokens"] == 13
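
Reviewer note, not part of the patch: a minimal sketch of how the new token-usage metadata surfaces to callers after this change. The model id below is illustrative, and AWS credentials/region are assumed to be configured in the environment.

# Illustrative usage sketch (hypothetical model id, credentials assumed).
from langchain_core.messages import HumanMessage

from langchain_community.chat_models import BedrockChat

chat = BedrockChat(model_id="anthropic.claude-v2")

# invoke(): per-call counts parsed from the x-amzn-bedrock-*-token-count headers
# are exposed on the message's response_metadata alongside the model id.
result = chat.invoke("Hello")
print(result.response_metadata["model_id"])
print(result.response_metadata["usage"])  # prompt_tokens / completion_tokens / total_tokens

# generate(): per-generation usage dicts are summed by _combine_llm_outputs,
# so the aggregate lands in LLMResult.llm_output["usage"].
batch = chat.generate([[HumanMessage(content="Hello")], [HumanMessage(content="Hello")]])
print(batch.llm_output["usage"]["total_tokens"])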