Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-14 17:07:25 +00:00
google-genai[patch]: fix streaming, function calling (#17268)
commit e4da7918f3
parent 96b5711a0c
@@ -35,6 +35,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     FunctionMessage,
     HumanMessage,
@@ -339,11 +340,23 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
             parts = [
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=message.name,
-                        response=message.content,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
                     )
                 )
             ]
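
For illustration, a self-contained sketch of the coercion this hunk adds for FunctionMessage content; the helper name _coerce_function_response is hypothetical (the real logic is inline in _parse_chat_history), but the branches mirror the patched code:

import json
from typing import Any, Dict


def _coerce_function_response(content: Any) -> Dict:
    # Mirror of the patched inline logic: parse JSON strings, pass dicts through.
    if not isinstance(content, str):
        response = content
    else:
        try:
            response = json.loads(content)
        except json.JSONDecodeError:
            response = content  # leave as str representation
    # The value handed to glm.FunctionResponse(response=...) is expected to be a
    # mapping, so anything that is not already a dict gets wrapped under "output".
    return response if isinstance(response, dict) else {"output": response}


print(_coerce_function_response('{"a": "b"}'))       # {'a': 'b'}
print(_coerce_function_response('["a"]'))            # {'output': ['a']}
print(_coerce_function_response("function output"))  # {'output': 'function output'}
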
@@ -364,12 +377,16 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


-def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = proto.Message.to_dict(first_part.function_call)
         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
     else:
         parts = response_candidate.content.parts

@@ -377,11 +394,14 @@ def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
             content: Union[str, List[Union[str, Dict]]] = parts[0].text
         else:
             content = [proto.Message.to_dict(part) for part in parts]
-        return AIMessage(content=content, additional_kwargs={})
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )


 def _response_to_result(
     response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
@@ -397,8 +417,8 @@ def _response_to_result(
             for safety_rating in candidate.safety_ratings
         ]
         generations.append(
-            ChatGeneration(
-                message=_parse_response_candidate(candidate),
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
                 generation_info=generation_info,
             )
         )
@@ -411,7 +431,10 @@ def _response_to_result(
             f"Feedback: {response.prompt_feedback}"
         )
         generations = [
-            ChatGeneration(message=AIMessage(content=""), generation_info={})
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
         ]
     return ChatResult(generations=generations, llm_output=llm_output)

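
A minimal sketch of the (Chunk if stream else Message) selection used in the hunks above; build_message is a hypothetical stand-in, while the merge-with-"+" behaviour of AIMessageChunk is standard langchain_core behaviour and is the reason chunk types are needed on the streaming path:

from langchain_core.messages import AIMessage, AIMessageChunk


def build_message(content: str, stream: bool) -> AIMessage:
    # Same class-selection pattern as the patched _parse_response_candidate.
    return (AIMessageChunk if stream else AIMessage)(content=content)


full = build_message("Hello", stream=False)  # plain AIMessage
part = build_message("Hel", stream=True) + build_message("lo", stream=True)
assert isinstance(part, AIMessageChunk)
assert part.content == "Hello"  # chunks support "+" and concatenate content
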
@@ -573,7 +596,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -597,7 +620,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
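
A usage sketch of the streaming path the two hunks above fix, assuming langchain-google-genai is installed and GOOGLE_API_KEY is set in the environment; with stream=True now threaded through _response_to_result, each yielded item is an AIMessageChunk wrapped in a ChatGenerationChunk:

from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
for chunk in llm.stream("Write a haiku about version control"):
    # Each chunk is an AIMessageChunk; print the incremental text as it arrives.
    print(chunk.content, end="", flush=True)
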
libs/partners/google-genai/poetry.lock (generated, 10 lines changed)
@@ -228,13 +228,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4

 [[package]]
 name = "google-api-core"
-version = "2.16.2"
+version = "2.17.0"
 description = "Google API client core library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-api-core-2.16.2.tar.gz", hash = "sha256:032d37b45d1d6bdaf68fb11ff621e2593263a239fa9246e2e94325f9c47876d2"},
-    {file = "google_api_core-2.16.2-py3-none-any.whl", hash = "sha256:449ca0e3f14c179b4165b664256066c7861610f70b6ffe54bb01a04e9b466929"},
+    {file = "google-api-core-2.17.0.tar.gz", hash = "sha256:de7ef0450faec7c75e0aea313f29ac870fdc44cfaec9d6499a9a17305980ef66"},
+    {file = "google_api_core-2.17.0-py3-none-any.whl", hash = "sha256:08ed79ed8e93e329de5e3e7452746b734e6bf8438d8d64dd3319d21d3164890c"},
 ]

 [package.dependencies]
@@ -448,7 +448,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.1.19"
+version = "0.1.21"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -458,7 +458,7 @@ develop = true
 [package.dependencies]
 anyio = ">=3,<5"
 jsonpatch = "^1.33"
-langsmith = ">=0.0.83,<0.1"
+langsmith = "^0.0.87"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"
@@ -1,5 +1,14 @@
 """Test chat model integration."""
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+
+from typing import Dict, List, Union
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture

@@ -58,3 +67,9 @@ def test_parse_history() -> None:
         "parts": [{"text": system_input}, {"text": text_question1}],
     }
     assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
+
+
+@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
+def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None:
+    function_message = FunctionMessage(name="search_tool", content=content)
+    _parse_chat_history([function_message], convert_system_message_to_human=True)
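
The new test only asserts that parsing a FunctionMessage does not raise. As an illustration (not part of the diff), the produced history can also be inspected; the import path langchain_google_genai.chat_models is an assumption here:

from langchain_core.messages import FunctionMessage
from langchain_google_genai.chat_models import _parse_chat_history

history = _parse_chat_history(
    [FunctionMessage(name="search_tool", content='{"a": "b"}')],
    convert_system_message_to_human=True,
)
# With this patch the single "user" turn carries a FunctionResponse whose
# response field is the parsed dict {"a": "b"} rather than the raw string.
print(history[0])
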