From e22c4d4eb0a3b2b1d4d3fb39c7cb78acccc350b2 Mon Sep 17 00:00:00 2001 From: hsuyuming Date: Thu, 8 Feb 2024 12:48:25 -0700 Subject: [PATCH] google-vertexai[patch]: fix _parse_response_candidate issue (#16647) **Description:** enable _parse_response_candidate to support complex structured formats. **Issue:** currently, if Gemini responds with a complex args format, people will get "TypeError: Object of type RepeatedComposite is not JSON serializable" error from _parse_response_candidate. Response candidate example: ``` content { role: "model" parts { function_call { name: "Information" args { fields { key: "people" value { list_value { values { string_value: "Joe is 30, his mom is Martha" } } } } } } } } finish_reason: STOP safety_ratings { category: HARM_CATEGORY_HARASSMENT probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_HATE_SPEECH probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_SEXUALLY_EXPLICIT probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_DANGEROUS_CONTENT probability: NEGLIGIBLE } ``` error msg: ``` Traceback (most recent call last): File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"})) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke input = step.invoke( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke return self.bound.invoke( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke self.generate_prompt( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) File 
"/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate raise e File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate self._generate_with_cache( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache return self._generate( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate generations = [ File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in message=_parse_response_candidate(c), File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate function_call["arguments"] = json.dumps( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps return _default_encoder.encode(obj) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode chunks = self.iterencode(o, _one_shot=True) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode return _iterencode(o, 0) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default raise TypeError(f'Object of type {o.__class__.__name__} ' TypeError: Object of type RepeatedComposite is not JSON serializable ``` **Twitter handle:** @abehsu1992626 --- .../langchain_google_vertexai/chat_models.py | 7 +- .../tests/unit_tests/test_chat_models.py | 114 ++++++++++++++++++ 2 files changed, 119 insertions(+), 2 deletions(-) diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py 
b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py index 32255d268b7..48709faebd7 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/chat_models.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, Iterator, List, Optional, Union, cast from urllib.parse import urlparse +import proto # type: ignore[import-untyped] import requests from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall @@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage: first_part = response_candidate.content.parts[0] if first_part.function_call: function_call = {"name": first_part.function_call.name} - # dump to match other function calling llm for now + function_call_args_dict = proto.Message.to_dict(first_part.function_call)[ + "args" + ] function_call["arguments"] = json.dumps( - {k: first_part.function_call.args[k] for k in first_part.function_call.args} + {k: function_call_args_dict[k] for k in function_call_args_dict} ) additional_kwargs["function_call"] = function_call return AIMessage(content=content, additional_kwargs=additional_kwargs) diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py index 6bcbb6e5abf..f24f418ba62 100644 --- a/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/google-vertexai/tests/unit_tests/test_chat_models.py @@ -1,22 +1,35 @@ """Test chat model integration.""" +import json from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from unittest.mock import MagicMock, Mock, patch import pytest +from google.cloud.aiplatform_v1beta1.types import ( + Content, + FunctionCall, + Part, +) +from google.cloud.aiplatform_v1beta1.types 
import ( + content as gapic_content_types, +) from langchain_core.messages import ( AIMessage, HumanMessage, SystemMessage, ) from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore +from vertexai.preview.generative_models import ( # type: ignore + Candidate, +) from langchain_google_vertexai.chat_models import ( ChatVertexAI, _parse_chat_history, _parse_chat_history_gemini, _parse_examples, + _parse_response_candidate, ) @@ -202,3 +215,104 @@ def test_default_params_gemini() -> None: message = HumanMessage(content=user_prompt) _ = model([message]) mock_start_chat.assert_called_once_with(history=[]) + + +@pytest.mark.parametrize( + "raw_candidate, expected", + [ + ( + gapic_content_types.Candidate( + content=Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="Information", + args={"name": "Ben"}, + ), + ) + ], + ) + ), + { + "name": "Information", + "arguments": {"name": "Ben"}, + }, + ), + ( + gapic_content_types.Candidate( + content=Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="Information", + args={"info": ["A", "B", "C"]}, + ), + ) + ], + ) + ), + { + "name": "Information", + "arguments": {"info": ["A", "B", "C"]}, + }, + ), + ( + gapic_content_types.Candidate( + content=Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="Information", + args={ + "people": [ + {"name": "Joe", "age": 30}, + {"name": "Martha"}, + ] + }, + ), + ) + ], + ) + ), + { + "name": "Information", + "arguments": { + "people": [ + {"name": "Joe", "age": 30}, + {"name": "Martha"}, + ] + }, + }, + ), + ( + gapic_content_types.Candidate( + content=Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="Information", + args={"info": [[1, 2, 3], [4, 5, 6]]}, + ), + ) + ], + ) + ), + { + "name": "Information", + "arguments": {"info": [[1, 2, 3], [4, 5, 6]]}, + }, + ), + ], +) +def test_parse_response_candidate(raw_candidate, expected) -> None: + 
response_candidate = Candidate._from_gapic(raw_candidate) + result = _parse_response_candidate(response_candidate) + result_arguments = json.loads( + result.additional_kwargs["function_call"]["arguments"] + ) + + assert result_arguments == expected["arguments"]