From 51bdfe04e945bdee1f0ad009bae0575e6d2ae811 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Wed, 3 Apr 2024 15:22:59 -0700
Subject: [PATCH] groq: handle streaming tool call case (#19978)

---
 .../groq/langchain_groq/chat_models.py     | 60 ++++++++++++++++---
 .../integration_tests/test_chat_models.py  | 59 ++++++++++++++++++
 2 files changed, 111 insertions(+), 8 deletions(-)

diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index e9ad9239742..e557eb26a56 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -225,11 +225,9 @@ class ChatGroq(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if self.streaming:
             stream_iter = self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
@@ -237,7 +235,6 @@ class ChatGroq(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {
             **params,
-            **({"stream": stream} if stream is not None else {}),
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
@@ -248,11 +245,9 @@ class ChatGroq(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if self.streaming:
             stream_iter = self._astream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             )
@@ -261,7 +256,6 @@ class ChatGroq(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {
             **params,
-            **({"stream": stream} if stream is not None else {}),
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)
@@ -275,6 +269,31 @@ class ChatGroq(BaseChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
+
+        # Groq API does not support streaming with tools yet; emit one chunk
+        if "tools" in kwargs:
+            response = self.client.create(
+                messages=message_dicts, **{**params, **kwargs}
+            )
+            chat_result = self._create_chat_result(response)
+            generation = chat_result.generations[0]
+            message = generation.message
+            chunk_ = ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=message.content, additional_kwargs=message.additional_kwargs
+                ),
+                generation_info=generation.generation_info,
+            )
+            if run_manager:
+                geninfo = chunk_.generation_info or {}
+                run_manager.on_llm_new_token(
+                    chunk_.text,
+                    chunk=chunk_,
+                    logprobs=geninfo.get("logprobs"),
+                )
+            yield chunk_
+            return
+
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
@@ -310,6 +329,31 @@ class ChatGroq(BaseChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
+
+        # Groq API does not support streaming with tools yet; emit one chunk
+        if "tools" in kwargs:
+            response = await self.async_client.create(
+                messages=message_dicts, **{**params, **kwargs}
+            )
+            chat_result = self._create_chat_result(response)
+            generation = chat_result.generations[0]
+            message = generation.message
+            chunk_ = ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=message.content, additional_kwargs=message.additional_kwargs
+                ),
+                generation_info=generation.generation_info,
+            )
+            if run_manager:
+                geninfo = chunk_.generation_info or {}
+                await run_manager.on_llm_new_token(
+                    chunk_.text,
+                    chunk=chunk_,
+                    logprobs=geninfo.get("logprobs"),
+                )
+            yield chunk_
+            return
+
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class = AIMessageChunk
diff --git a/libs/partners/groq/tests/integration_tests/test_chat_models.py b/libs/partners/groq/tests/integration_tests/test_chat_models.py
index 3ef7fb1639f..f2065091366 100644
--- a/libs/partners/groq/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -6,6 +6,7 @@ from typing import Any
 import pytest
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
@@ -272,6 +273,64 @@ def test_tool_choice_bool() -> None:
     assert tool_call["type"] == "function"
 
+
+def test_streaming_tool_call() -> None:
+    """Test that streaming with tools yields an aggregated tool call."""
+    llm = ChatGroq()
+
+    class MyTool(BaseModel):
+        name: str
+        age: int
+
+    with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
+
+    resp = with_tool.stream("Who was the 27 year old named Erick?")
+    additional_kwargs = None
+    for chunk in resp:
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.content == ""  # should just be tool call
+        additional_kwargs = chunk.additional_kwargs
+
+    assert additional_kwargs is not None
+    tool_calls = additional_kwargs["tool_calls"]
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == "MyTool"
+    assert json.loads(tool_call["function"]["arguments"]) == {
+        "age": 27,
+        "name": "Erick",
+    }
+    assert tool_call["type"] == "function"
+
+
+async def test_astreaming_tool_call() -> None:
+    """Test that async streaming with tools yields an aggregated tool call."""
+    llm = ChatGroq()
+
+    class MyTool(BaseModel):
+        name: str
+        age: int
+
+    with_tool = llm.bind_tools([MyTool], tool_choice="MyTool")
+
+    resp = with_tool.astream("Who was the 27 year old named Erick?")
+    additional_kwargs = None
+    async for chunk in resp:
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.content == ""  # should just be tool call
+        additional_kwargs = chunk.additional_kwargs
+
+    assert additional_kwargs is not None
+    tool_calls = additional_kwargs["tool_calls"]
+    assert len(tool_calls) == 1
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == "MyTool"
+    assert json.loads(tool_call["function"]["arguments"]) == {
+        "age": 27,
+        "name": "Erick",
+    }
+    assert tool_call["type"] == "function"
+
 
 @pytest.mark.scheduled
 def test_json_mode_structured_output() -> None:
     """Test with_structured_output with json"""
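
Usage note: a minimal sketch of the behavior this patch enables, not part of the PR itself. It assumes a valid GROQ_API_KEY in the environment and a langchain-groq build that includes this change; MyTool mirrors the schema used in the integration tests above, and the BaseModel import follows the pydantic_v1 shim those tests use. With tools bound, stream() now falls back to a single non-streaming request (the Groq API does not yet support streaming with tools) and yields one aggregated chunk:

    from langchain_core.pydantic_v1 import BaseModel

    from langchain_groq import ChatGroq


    class MyTool(BaseModel):
        name: str
        age: int


    # tool_choice="MyTool" forces the model to emit a MyTool call.
    llm = ChatGroq().bind_tools([MyTool], tool_choice="MyTool")

    # With this patch, exactly one chunk arrives; its additional_kwargs
    # carries the fully formed tool call instead of partial deltas.
    for chunk in llm.stream("Who was the 27 year old named Erick?"):
        print(chunk.additional_kwargs.get("tool_calls"))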