diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index a41da6e1d31..10f173b7f1b 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -307,7 +307,7 @@ class ChatGroq(BaseChatModel):
             )
             chat_result = self._create_chat_result(response)
             generation = chat_result.generations[0]
-            message = generation.message
+            message = cast(AIMessage, generation.message)
             tool_call_chunks = [
                 {
                     "name": rtc["function"].get("name"),
@@ -322,6 +322,7 @@ class ChatGroq(BaseChatModel):
                     content=message.content,
                     additional_kwargs=message.additional_kwargs,
                     tool_call_chunks=tool_call_chunks,
+                    usage_metadata=message.usage_metadata,
                 ),
                 generation_info=generation.generation_info,
             )
@@ -337,30 +338,30 @@ class ChatGroq(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         for chunk in self.client.create(messages=message_dicts, **params):
             if not isinstance(chunk, dict):
                 chunk = chunk.dict()
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
+                )
+            yield generation_chunk
 
     async def _astream(
         self,
@@ -378,7 +379,7 @@ class ChatGroq(BaseChatModel):
             )
             chat_result = self._create_chat_result(response)
             generation = chat_result.generations[0]
-            message = generation.message
+            message = cast(AIMessage, generation.message)
             tool_call_chunks = [
                 {
                     "name": rtc["function"].get("name"),
@@ -393,6 +394,7 @@ class ChatGroq(BaseChatModel):
                     content=message.content,
                     additional_kwargs=message.additional_kwargs,
                     tool_call_chunks=tool_call_chunks,
+                    usage_metadata=message.usage_metadata,
                 ),
                 generation_info=generation.generation_info,
             )
@@ -408,7 +410,7 @@ class ChatGroq(BaseChatModel):
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         async for chunk in await self.async_client.create(
             messages=message_dicts, **params
         ):
@@ -417,25 +419,25 @@ class ChatGroq(BaseChatModel):
             if len(chunk["choices"]) == 0:
                 continue
             choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
+            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
                 generation_info["finish_reason"] = finish_reason
             logprobs = choice.get("logprobs")
             if logprobs:
                 generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
+            default_chunk_class = message_chunk.__class__
+            generation_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
                 await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                    token=generation_chunk.text,
+                    chunk=generation_chunk,
+                    logprobs=logprobs,
                 )
-            yield chunk
+            yield generation_chunk
 
     #
     # Internal methods
@@ -459,8 +461,19 @@ class ChatGroq(BaseChatModel):
         generations = []
         if not isinstance(response, dict):
             response = response.dict()
+        token_usage = response.get("usage", {})
         for res in response["choices"]:
             message = _convert_dict_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                input_tokens = token_usage.get("prompt_tokens", 0)
+                output_tokens = token_usage.get("completion_tokens", 0)
+                message.usage_metadata = {
+                    "input_tokens": input_tokens,
+                    "output_tokens": output_tokens,
+                    "total_tokens": token_usage.get(
+                        "total_tokens", input_tokens + output_tokens
+                    ),
+                }
             generation_info = dict(finish_reason=res.get("finish_reason"))
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
@@ -469,7 +482,6 @@ class ChatGroq(BaseChatModel):
                 generation_info=generation_info,
             )
             generations.append(gen)
-        token_usage = response.get("usage", {})
         llm_output = {
             "token_usage": token_usage,
             "model_name": self.model_name,
@@ -892,9 +904,11 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
-def _convert_delta_to_message_chunk(
-    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Mapping[str, Any], default_class: Type[BaseMessageChunk]
 ) -> BaseMessageChunk:
+    choice = chunk["choices"][0]
+    _dict = choice["delta"]
     role = cast(str, _dict.get("role"))
     content = cast(str, _dict.get("content") or "")
     additional_kwargs: Dict = {}
@@ -909,7 +923,21 @@
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
     elif role == "assistant" or default_class == AIMessageChunk:
-        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+        if usage := (chunk.get("x_groq") or {}).get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = {
+                "input_tokens": input_tokens,
+                "output_tokens": output_tokens,
+                "total_tokens": usage.get("total_tokens", input_tokens + output_tokens),
+            }
+        else:
+            usage_metadata = None
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            usage_metadata=usage_metadata,
+        )
    elif role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
     elif role == "function" or default_class == FunctionMessageChunk:
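Taken together, the chat_models.py changes above surface Groq's token accounting on the returned messages: _create_chat_result copies the response "usage" block (prompt_tokens / completion_tokens / total_tokens) into AIMessage.usage_metadata as input_tokens / output_tokens / total_tokens, and _convert_chunk_to_message_chunk does the same for the final streamed chunk using the x_groq payload. A minimal consumer-side sketch of the resulting behavior, not part of the diff; it assumes a configured GROQ_API_KEY and the printed numbers are made up:

    from langchain_groq import ChatGroq

    llm = ChatGroq(max_tokens=10)
    msg = llm.invoke("Hello!")  # AIMessage built by _create_chat_result
    # usage_metadata is now populated from the Groq "usage" block
    print(msg.usage_metadata)
    # e.g. {"input_tokens": 11, "output_tokens": 10, "total_tokens": 21}

The remaining files below are the supporting changes: lock/dependency bumps so langchain-core provides usage_metadata, plus integration tests covering the new behavior.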
diff --git a/libs/partners/groq/poetry.lock b/libs/partners/groq/poetry.lock
index be2e64be125..354e68ddad1 100644
--- a/libs/partners/groq/poetry.lock
+++ b/libs/partners/groq/poetry.lock
@@ -323,7 +323,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.2.4"
+version = "0.2.5"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -332,15 +332,12 @@ develop = true
 
 [package.dependencies]
 jsonpatch = "^1.33"
-langsmith = "^0.1.66"
+langsmith = "^0.1.75"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
 tenacity = "^8.1.0"
 
-[package.extras]
-extended-testing = ["jinja2 (>=3,<4)"]
-
 [package.source]
 type = "directory"
 url = "../../core"
@@ -364,13 +361,13 @@ url = "../../standard-tests"
 
 [[package]]
 name = "langsmith"
-version = "0.1.73"
+version = "0.1.76"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
-    {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
+    {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"},
+    {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"},
 ]
 
 [package.dependencies]
@@ -918,4 +915,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "672ecb755a4d938d114d4ffa96455758ecc05943c06e49e9bad3dfe65ee3c810"
+content-hash = "3cbd3deff4e93bc6337655edfbb328e3e2d5c3dff337ce911c4327f39bc231f9"
"../../standard-tests" [[package]] name = "langsmith" -version = "0.1.73" +version = "0.1.76" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"}, - {file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"}, + {file = "langsmith-0.1.76-py3-none-any.whl", hash = "sha256:4b8cb14f2233d9673ce9e6e3d545359946d9690a2c1457ab01e7459ec97b964e"}, + {file = "langsmith-0.1.76.tar.gz", hash = "sha256:5829f997495c0f9a39f91fe0a57e0cb702e8642e6948945f5bb9f46337db7732"}, ] [package.dependencies] @@ -918,4 +915,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "672ecb755a4d938d114d4ffa96455758ecc05943c06e49e9bad3dfe65ee3c810" +content-hash = "3cbd3deff4e93bc6337655edfbb328e3e2d5c3dff337ce911c4327f39bc231f9" diff --git a/libs/partners/groq/pyproject.toml b/libs/partners/groq/pyproject.toml index 4e22a9bb120..15f2334a804 100644 --- a/libs/partners/groq/pyproject.toml +++ b/libs/partners/groq/pyproject.toml @@ -12,7 +12,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = ">=0.2.0,<0.3" +langchain-core = ">=0.2.2,<0.3" groq = ">=0.4.1,<1" [tool.poetry.group.test] diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py index 047497c5d1d..f96d1a22975 100644 --- a/libs/partners/groq/tests/integration_tests/test_chat_models.py +++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py @@ -1,7 +1,7 @@ """Test ChatGroq chat model.""" import json -from typing import Any +from typing import Any, Optional import pytest from langchain_core.messages import ( @@ -93,9 +93,28 @@ async def test_astream() -> None: """Test streaming tokens from Groq.""" chat = ChatGroq(max_tokens=10) + full: Optional[BaseMessageChunk] = None + chunks_with_token_counts = 0 async for token in chat.astream("Welcome to the Groqetship!"): - assert isinstance(token, BaseMessageChunk) + assert isinstance(token, AIMessageChunk) assert isinstance(token.content, str) + full = token if full is None else full + token + if token.usage_metadata is not None: + chunks_with_token_counts += 1 + if chunks_with_token_counts != 1: + raise AssertionError( + "Expected exactly one chunk with token counts. " + "AIMessageChunk aggregation adds counts. Check that " + "this is behaving properly." 
+        )
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert (
+        full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
+        == full.usage_metadata["total_tokens"]
+    )
 
 
 #
diff --git a/libs/partners/groq/tests/integration_tests/test_standard.py b/libs/partners/groq/tests/integration_tests/test_standard.py
index 8224adc3ec8..b458d8adf19 100644
--- a/libs/partners/groq/tests/integration_tests/test_standard.py
+++ b/libs/partners/groq/tests/integration_tests/test_standard.py
@@ -9,22 +9,11 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 from langchain_groq import ChatGroq
 
 
-class TestMistralStandard(ChatModelIntegrationTests):
+class TestGroqStandard(ChatModelIntegrationTests):
     @pytest.fixture
     def chat_model_class(self) -> Type[BaseChatModel]:
         return ChatGroq
 
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
-
     @pytest.mark.xfail(reason="Not yet implemented.")
     def test_tool_message_histories_list_content(
         self,
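The updated test_astream above also pins down the intended streaming contract: exactly one chunk (the final one, which carries Groq's x_groq usage block) has usage_metadata set, and adding AIMessageChunk objects together aggregates those counts onto the combined message. A small sketch of that usage pattern, for illustration only and assuming a configured GROQ_API_KEY:

    import asyncio

    from langchain_groq import ChatGroq

    async def main() -> None:
        chat = ChatGroq(max_tokens=10)
        full = None
        async for chunk in chat.astream("Welcome to the Groqetship!"):
            # only the final chunk carries usage_metadata; "+" merges it in
            full = chunk if full is None else full + chunk
        assert full is not None
        print(full.usage_metadata)
        # e.g. {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}

    asyncio.run(main())

With usage metadata now reported in both the sync and streaming paths, the standard test_usage_metadata xfail marker is dropped from test_standard.py, and the class is renamed from TestMistralStandard (a copy-paste leftover) to TestGroqStandard.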