mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 04:25:46 +00:00
groq: add back streaming tool calls (#26391)
api no longer throws an error https://console.groq.com/docs/tool-use#streaming
This commit is contained in:
parent
396c0aee4d
commit
54c85087e2
@ -52,7 +52,6 @@ from langchain_core.messages import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
JsonOutputParser,
|
JsonOutputParser,
|
||||||
PydanticOutputParser,
|
PydanticOutputParser,
|
||||||
@ -385,9 +384,11 @@ class ChatGroq(BaseChatModel):
|
|||||||
values["temperature"] = 1e-8
|
values["temperature"] = 1e-8
|
||||||
|
|
||||||
client_params = {
|
client_params = {
|
||||||
"api_key": values["groq_api_key"].get_secret_value()
|
"api_key": (
|
||||||
|
values["groq_api_key"].get_secret_value()
|
||||||
if values["groq_api_key"]
|
if values["groq_api_key"]
|
||||||
else None,
|
else None
|
||||||
|
),
|
||||||
"base_url": values["groq_api_base"],
|
"base_url": values["groq_api_base"],
|
||||||
"timeout": values["request_timeout"],
|
"timeout": values["request_timeout"],
|
||||||
"max_retries": values["max_retries"],
|
"max_retries": values["max_retries"],
|
||||||
@ -502,42 +503,6 @@ class ChatGroq(BaseChatModel):
|
|||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
|
||||||
# groq api does not support streaming with tools yet
|
|
||||||
if "tools" in kwargs:
|
|
||||||
response = self.client.create(
|
|
||||||
messages=message_dicts, **{**params, **kwargs}
|
|
||||||
)
|
|
||||||
chat_result = self._create_chat_result(response)
|
|
||||||
generation = chat_result.generations[0]
|
|
||||||
message = cast(AIMessage, generation.message)
|
|
||||||
tool_call_chunks = [
|
|
||||||
create_tool_call_chunk(
|
|
||||||
name=rtc["function"].get("name"),
|
|
||||||
args=rtc["function"].get("arguments"),
|
|
||||||
id=rtc.get("id"),
|
|
||||||
index=rtc.get("index"),
|
|
||||||
)
|
|
||||||
for rtc in message.additional_kwargs.get("tool_calls", [])
|
|
||||||
]
|
|
||||||
chunk_ = ChatGenerationChunk(
|
|
||||||
message=AIMessageChunk(
|
|
||||||
content=message.content,
|
|
||||||
additional_kwargs=message.additional_kwargs,
|
|
||||||
tool_call_chunks=tool_call_chunks,
|
|
||||||
usage_metadata=message.usage_metadata,
|
|
||||||
),
|
|
||||||
generation_info=generation.generation_info,
|
|
||||||
)
|
|
||||||
if run_manager:
|
|
||||||
geninfo = chunk_.generation_info or {}
|
|
||||||
run_manager.on_llm_new_token(
|
|
||||||
chunk_.text,
|
|
||||||
chunk=chunk_,
|
|
||||||
logprobs=geninfo.get("logprobs"),
|
|
||||||
)
|
|
||||||
yield chunk_
|
|
||||||
return
|
|
||||||
|
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||||
@ -574,42 +539,6 @@ class ChatGroq(BaseChatModel):
|
|||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
|
||||||
# groq api does not support streaming with tools yet
|
|
||||||
if "tools" in kwargs:
|
|
||||||
response = await self.async_client.create(
|
|
||||||
messages=message_dicts, **{**params, **kwargs}
|
|
||||||
)
|
|
||||||
chat_result = self._create_chat_result(response)
|
|
||||||
generation = chat_result.generations[0]
|
|
||||||
message = cast(AIMessage, generation.message)
|
|
||||||
tool_call_chunks = [
|
|
||||||
{
|
|
||||||
"name": rtc["function"].get("name"),
|
|
||||||
"args": rtc["function"].get("arguments"),
|
|
||||||
"id": rtc.get("id"),
|
|
||||||
"index": rtc.get("index"),
|
|
||||||
}
|
|
||||||
for rtc in message.additional_kwargs.get("tool_calls", [])
|
|
||||||
]
|
|
||||||
chunk_ = ChatGenerationChunk(
|
|
||||||
message=AIMessageChunk(
|
|
||||||
content=message.content,
|
|
||||||
additional_kwargs=message.additional_kwargs,
|
|
||||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
|
||||||
usage_metadata=message.usage_metadata,
|
|
||||||
),
|
|
||||||
generation_info=generation.generation_info,
|
|
||||||
)
|
|
||||||
if run_manager:
|
|
||||||
geninfo = chunk_.generation_info or {}
|
|
||||||
await run_manager.on_llm_new_token(
|
|
||||||
chunk_.text,
|
|
||||||
chunk=chunk_,
|
|
||||||
logprobs=geninfo.get("logprobs"),
|
|
||||||
)
|
|
||||||
yield chunk_
|
|
||||||
return
|
|
||||||
|
|
||||||
params = {**params, **kwargs, "stream": True}
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||||
|
@ -1,42 +1,6 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: TestGroqStandard.test_serdes[serialized]
|
# name: TestGroqStandard.test_serdes[serialized]
|
||||||
dict({
|
dict({
|
||||||
'graph': dict({
|
|
||||||
'edges': list([
|
|
||||||
dict({
|
|
||||||
'source': 0,
|
|
||||||
'target': 1,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'source': 1,
|
|
||||||
'target': 2,
|
|
||||||
}),
|
|
||||||
]),
|
|
||||||
'nodes': list([
|
|
||||||
dict({
|
|
||||||
'data': 'ChatGroqInput',
|
|
||||||
'id': 0,
|
|
||||||
'type': 'schema',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'id': list([
|
|
||||||
'langchain_groq',
|
|
||||||
'chat_models',
|
|
||||||
'ChatGroq',
|
|
||||||
]),
|
|
||||||
'name': 'ChatGroq',
|
|
||||||
}),
|
|
||||||
'id': 1,
|
|
||||||
'type': 'runnable',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': 'ChatGroqOutput',
|
|
||||||
'id': 2,
|
|
||||||
'type': 'schema',
|
|
||||||
}),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
'id': list([
|
'id': list([
|
||||||
'langchain_groq',
|
'langchain_groq',
|
||||||
'chat_models',
|
'chat_models',
|
||||||
|
Loading…
Reference in New Issue
Block a user