mirror of https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00

google-genai[patch]: fix streaming, function calling (#17268)

This commit is contained in:
parent 96b5711a0c
commit e4da7918f3
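
The hunks that follow patch the Gemini chat model implementation (the per-file header was not preserved in this mirror): streaming now yields chunk message/generation types, and FunctionMessage contents are normalized into dict payloads for Gemini function responses. As a rough sketch of the user-facing behavior this targets (not part of the diff; assumes langchain-google-genai is installed and GOOGLE_API_KEY is set):

    from langchain_google_genai import ChatGoogleGenerativeAI

    # After this patch, streaming yields AIMessageChunk-backed generations
    # instead of plain AIMessage objects wrapped in ChatGeneration.
    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    for chunk in llm.stream("Write a haiku about rivers."):
        print(chunk.content, end="", flush=True)
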
@@ -35,6 +35,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     FunctionMessage,
     HumanMessage,
@@ -339,11 +340,23 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
             parts = [
                 glm.Part(
                     function_response=glm.FunctionResponse(
                         name=message.name,
-                        response=message.content,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
                     )
                 )
             ]
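
The new branch above parses a FunctionMessage's string content as JSON when possible and wraps anything that is not already a dict as {"output": ...} before building glm.FunctionResponse. A standalone sketch of that normalization (the helper name is hypothetical, not part of the library):

    import json
    from typing import Any, Union


    def normalize_function_response(content: Union[str, dict, list]) -> dict:
        """Mirror the patched logic: parse JSON strings, then wrap non-dict results."""
        response: Any
        if not isinstance(content, str):
            response = content
        else:
            try:
                response = json.loads(content)
            except json.JSONDecodeError:
                response = content  # leave as str representation
        return response if isinstance(response, dict) else {"output": response}


    assert normalize_function_response('{"a":"b"}') == {"a": "b"}
    assert normalize_function_response('["a"]') == {"output": ["a"]}
    assert normalize_function_response("function output") == {"output": "function output"}
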
@@ -364,12 +377,16 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


-def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = proto.Message.to_dict(first_part.function_call)
         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
-        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
     else:
         parts = response_candidate.content.parts

@@ -377,11 +394,14 @@ def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
             content: Union[str, List[Union[str, Dict]]] = parts[0].text
         else:
             content = [proto.Message.to_dict(part) for part in parts]
-        return AIMessage(content=content, additional_kwargs={})
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )


 def _response_to_result(
     response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
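
_parse_response_candidate now builds AIMessageChunk when called from the streaming path, and _response_to_result gains a stream flag to do the same for generations. Chunk types matter because LangChain merges them with + as tokens arrive; a minimal illustration using standard langchain_core behavior (not specific to this package):

    from langchain_core.messages import AIMessageChunk

    # Message chunks concatenate; plain AIMessage objects do not support this.
    merged = AIMessageChunk(content="Hello, ") + AIMessageChunk(content="world")
    assert merged.content == "Hello, world"
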
@@ -397,8 +417,8 @@ def _response_to_result(
             for safety_rating in candidate.safety_ratings
         ]
         generations.append(
-            ChatGeneration(
-                message=_parse_response_candidate(candidate),
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
                 generation_info=generation_info,
             )
         )
@@ -411,7 +431,10 @@ def _response_to_result(
             f"Feedback: {response.prompt_feedback}"
         )
         generations = [
-            ChatGeneration(message=AIMessage(content=""), generation_info={})
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
         ]
     return ChatResult(generations=generations, llm_output=llm_output)

@@ -573,7 +596,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -597,7 +620,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(chunk)
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
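
With stream=True threaded through, both the synchronous _stream and asynchronous _astream paths now emit ChatGenerationChunk objects. A rough async usage sketch under the same assumptions as the streaming example above:

    import asyncio

    from langchain_google_genai import ChatGoogleGenerativeAI


    async def main() -> None:
        llm = ChatGoogleGenerativeAI(model="gemini-pro")
        async for chunk in llm.astream("Write a haiku about rivers."):
            print(chunk.content, end="", flush=True)


    asyncio.run(main())
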
libs/partners/google-genai/poetry.lock (generated, 10 lines changed)
@@ -228,13 +228,13 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4

 [[package]]
 name = "google-api-core"
-version = "2.16.2"
+version = "2.17.0"
 description = "Google API client core library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-api-core-2.16.2.tar.gz", hash = "sha256:032d37b45d1d6bdaf68fb11ff621e2593263a239fa9246e2e94325f9c47876d2"},
-    {file = "google_api_core-2.16.2-py3-none-any.whl", hash = "sha256:449ca0e3f14c179b4165b664256066c7861610f70b6ffe54bb01a04e9b466929"},
+    {file = "google-api-core-2.17.0.tar.gz", hash = "sha256:de7ef0450faec7c75e0aea313f29ac870fdc44cfaec9d6499a9a17305980ef66"},
+    {file = "google_api_core-2.17.0-py3-none-any.whl", hash = "sha256:08ed79ed8e93e329de5e3e7452746b734e6bf8438d8d64dd3319d21d3164890c"},
 ]

 [package.dependencies]
@@ -448,7 +448,7 @@ files = [

 [[package]]
 name = "langchain-core"
-version = "0.1.19"
+version = "0.1.21"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -458,7 +458,7 @@ develop = true
 [package.dependencies]
 anyio = ">=3,<5"
 jsonpatch = "^1.33"
-langsmith = ">=0.0.83,<0.1"
+langsmith = "^0.0.87"
 packaging = "^23.2"
 pydantic = ">=1,<3"
 PyYAML = ">=5.3"

(The remaining hunks are from the package's unit test file; its per-file header was not preserved in this mirror.)

@@ -1,5 +1,14 @@
 """Test chat model integration."""
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from typing import Dict, List, Union
+
+import pytest
+from langchain_core.messages import (
+    AIMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture

@@ -58,3 +67,9 @@ def test_parse_history() -> None:
         "parts": [{"text": system_input}, {"text": text_question1}],
     }
     assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
+
+
+@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
+def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None:
+    function_message = FunctionMessage(name="search_tool", content=content)
+    _parse_chat_history([function_message], convert_system_message_to_human=True)
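
The parametrized test above only checks that _parse_chat_history accepts FunctionMessage content given as a JSON list, a JSON object, or a plain string without raising. A quick interactive check of the same path (the import location and the printed shape follow the patched module and the normalization shown earlier; both are assumptions here, not assertions made by the test):

    from langchain_core.messages import FunctionMessage
    from langchain_google_genai.chat_models import _parse_chat_history

    for content in ['["a"]', '{"a":"b"}', "function output"]:
        message = FunctionMessage(name="search_tool", content=content)
        history = _parse_chat_history([message], convert_system_message_to_human=True)
        print(history)  # one entry with role "user" carrying a function_response part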