diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
index 32255d268b7..48709faebd7 100644
--- a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
+++ b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, Iterator, List, Optional, Union, cast
 from urllib.parse import urlparse
 
+import proto  # type: ignore[import-untyped]
 import requests
 from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
 from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
@@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = {"name": first_part.function_call.name}
 
-        # dump to match other function calling llm for now
+        function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
+            "args"
+        ]
         function_call["arguments"] = json.dumps(
-            {k: first_part.function_call.args[k] for k in first_part.function_call.args}
+            {k: function_call_args_dict[k] for k in function_call_args_dict}
         )
         additional_kwargs["function_call"] = function_call
     return AIMessage(content=content, additional_kwargs=additional_kwargs)
diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
index 6bcbb6e5abf..f24f418ba62 100644
--- a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py
@@ -1,22 +1,35 @@
 """Test chat model integration."""
+import json
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from google.cloud.aiplatform_v1beta1.types import (
+    Content,
+    FunctionCall,
+    Part,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    content as gapic_content_types,
+)
 from langchain_core.messages import (
     AIMessage,
     HumanMessage,
     SystemMessage,
 )
 from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: ignore
+from vertexai.preview.generative_models import (  # type: ignore
+    Candidate,
+)
 
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
     _parse_chat_history_gemini,
     _parse_examples,
+    _parse_response_candidate,
 )
@@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
     message = HumanMessage(content=user_prompt)
     _ = model([message])
     mock_start_chat.assert_called_once_with(history=[])
+
+
+@pytest.mark.parametrize(
+    "raw_candidate, expected",
+    [
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Ben"},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"name": "Ben"},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": ["A", "B", "C"]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": ["A", "B", "C"]},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={
+                                    "people": [
+                                        {"name": "Joe", "age": 30},
+                                        {"name": "Martha"},
+                                    ]
+                                },
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {
+                    "people": [
+                        {"name": "Joe", "age": 30},
+                        {"name": "Martha"},
+                    ]
+                },
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": [[1, 2, 3], [4, 5, 6]]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
+            },
+        ),
+    ],
+)
+def test_parse_response_candidate(raw_candidate, expected) -> None:
+    response_candidate = Candidate._from_gapic(raw_candidate)
+    result = _parse_response_candidate(response_candidate)
+    result_arguments = json.loads(
+        result.additional_kwargs["function_call"]["arguments"]
+    )
+
+    assert result_arguments == expected["arguments"]
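Reviewer note (not part of the diff): a minimal sketch of why the parser switches to proto.Message.to_dict. With the previous dict comprehension, nested values pulled out of function_call.args appear to remain proto-plus composite wrappers, which json.dumps rejects; to_dict recursively converts the underlying protobuf Struct into plain dicts, lists, and scalars. The snippet assumes the package's existing proto-plus and google-cloud-aiplatform dependencies and mirrors the nested "people" test case above; it is illustrative only.

    import json

    import proto  # type: ignore[import-untyped]
    from google.cloud.aiplatform_v1beta1.types import FunctionCall

    fc = FunctionCall(
        name="Information",
        args={"people": [{"name": "Joe", "age": 30}, {"name": "Martha"}]},
    )

    # proto.Message.to_dict converts the protobuf Struct recursively into
    # plain Python types, so nested arguments survive json.dumps.
    args_dict = proto.Message.to_dict(fc)["args"]
    print(json.dumps(args_dict))
    # Note: Struct numbers are doubles, so 30 round-trips as 30.0, which is
    # why the test compares parsed JSON values rather than raw strings.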