groq[patch]: cherry-pick 26391 into v0.3rc (#26397)

Co-authored-by: Erick Friis <erick@langchain.dev>
Authored by: ccurme on 2024-09-12 14:24:24 -04:00
Committed by: GitHub
Parent: 40df0249fb
Commit: 35aacff6c6


@@ -52,7 +52,6 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
-from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
@@ -387,9 +386,9 @@ class ChatGroq(BaseChatModel):
             self.temperature = 1e-8
         client_params: Dict[str, Any] = {
-            "api_key": self.groq_api_key.get_secret_value()
-            if self.groq_api_key
-            else None,
+            "api_key": (
+                self.groq_api_key.get_secret_value() if self.groq_api_key else None
+            ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
             "max_retries": self.max_retries,
@@ -504,42 +503,6 @@ class ChatGroq(BaseChatModel):
     ) -> Iterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = self.client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                create_tool_call_chunk(
-                    name=rtc["function"].get("name"),
-                    args=rtc["function"].get("arguments"),
-                    id=rtc.get("id"),
-                    index=rtc.get("index"),
-                )
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
         params = {**params, **kwargs, "stream": True}
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
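
Note: the workaround removed above built its tool call chunks with the helper imported as create_tool_call_chunk (also dropped in the first hunk). As a rough sketch of what that helper returns, with purely illustrative values:

from langchain_core.messages.tool import tool_call_chunk

# Hypothetical values; the helper packages them into a ToolCallChunk TypedDict.
chunk = tool_call_chunk(
    name="get_weather",
    args='{"city": "Paris"}',
    id="call_abc123",
    index=0,
)
print(chunk)
# e.g. {'name': 'get_weather', 'args': '{"city": "Paris"}',
#       'id': 'call_abc123', 'index': 0, 'type': 'tool_call_chunk'}

The async variant in the next hunk built the same dicts by hand instead of calling the helper, which is why it carried a type: ignore comment.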
@@ -576,42 +539,6 @@ class ChatGroq(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         message_dicts, params = self._create_message_dicts(messages, stop)
-        # groq api does not support streaming with tools yet
-        if "tools" in kwargs:
-            response = await self.async_client.create(
-                messages=message_dicts, **{**params, **kwargs}
-            )
-            chat_result = self._create_chat_result(response)
-            generation = chat_result.generations[0]
-            message = cast(AIMessage, generation.message)
-            tool_call_chunks = [
-                {
-                    "name": rtc["function"].get("name"),
-                    "args": rtc["function"].get("arguments"),
-                    "id": rtc.get("id"),
-                    "index": rtc.get("index"),
-                }
-                for rtc in message.additional_kwargs.get("tool_calls", [])
-            ]
-            chunk_ = ChatGenerationChunk(
-                message=AIMessageChunk(
-                    content=message.content,
-                    additional_kwargs=message.additional_kwargs,
-                    tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
-                    usage_metadata=message.usage_metadata,
-                ),
-                generation_info=generation.generation_info,
-            )
-            if run_manager:
-                geninfo = chunk_.generation_info or {}
-                await run_manager.on_llm_new_token(
-                    chunk_.text,
-                    chunk=chunk_,
-                    logprobs=geninfo.get("logprobs"),
-                )
-            yield chunk_
-            return
         params = {**params, **kwargs, "stream": True}
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
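
With the workaround removed, streaming with bound tools goes through the regular streaming path and tool calls arrive as incremental chunks. A minimal usage sketch, assuming GROQ_API_KEY is set in the environment; the model name and tool are illustrative and not part of this diff:

from langchain_core.tools import tool
from langchain_groq import ChatGroq


@tool
def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"Sunny in {city}"


llm = ChatGroq(model="llama-3.1-8b-instant").bind_tools([get_weather])

# Tool calls now stream as real chunks instead of one synthesized chunk.
for chunk in llm.stream("What's the weather in Paris?"):
    print(chunk.tool_call_chunks or chunk.content)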