diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 05dc4724570..8f1e1f68046 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -52,7 +52,6 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
@@ -387,9 +386,9 @@ class ChatGroq(BaseChatModel):
             self.temperature = 1e-8
 
         client_params: Dict[str, Any] = {
-            "api_key": self.groq_api_key.get_secret_value()
-            if self.groq_api_key
-            else None,
+            "api_key": (
+                self.groq_api_key.get_secret_value() if self.groq_api_key else None
+            ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
             "max_retries": self.max_retries,
@@ -504,42 +503,6 @@ class ChatGroq(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
 
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = self.client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                create_tool_call_chunk(
-                    name=rtc["function"].get("name"),
-                    args=rtc["function"].get("arguments"),
-                    id=rtc.get("id"),
-                    index=rtc.get("index"),
-                )
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
-
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
@@ -576,42 +539,6 @@ class ChatGroq(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
 
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = await self.async_client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                {
-                    "name": rtc["function"].get("name"),
-                    "args": rtc["function"].get("arguments"),
-                    "id": rtc.get("id"),
-                    "index": rtc.get("index"),
-                }
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                await run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
-
         params = {**params, **kwargs, "stream": True}
 
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
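
With the workaround removed, `_stream`/`_astream` no longer fall back to a single non-streaming request whenever `tools` is passed; tool calls are streamed incrementally like any other delta and surface as `tool_call_chunks` on each `AIMessageChunk`. A minimal caller-side sketch of what that enables, assuming `GROQ_API_KEY` is set; the `get_weather` tool and the model name are illustrative placeholders, not part of this diff:

```python
from langchain_core.tools import tool
from langchain_groq import ChatGroq


@tool
def get_weather(city: str) -> str:
    """Return a (fake) weather report for a city."""
    return f"It is always sunny in {city}."


llm = ChatGroq(model="llama3-8b-8192")  # hypothetical model choice
llm_with_tools = llm.bind_tools([get_weather])

# Each chunk is an AIMessageChunk; adding chunks together merges their
# tool_call_chunks into complete tool calls on the aggregated message.
aggregate = None
for chunk in llm_with_tools.stream("What's the weather in Tokyo?"):
    print(chunk.tool_call_chunks)  # partial name/args deltas as they arrive
    aggregate = chunk if aggregate is None else aggregate + chunk

print(aggregate.tool_calls)
```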