google-vertexai[patch]: fix _parse_response_candidate issue (#16647)

**Description:** Enable `_parse_response_candidate` to handle function-call
arguments with complex (nested) structures, such as lists and nested structs.
  **Issue:** 
Currently, if Gemini responds with a function call whose args have a complex
format (for example a list value), users get a
"TypeError: Object of type RepeatedComposite is not JSON serializable"
error from `_parse_response_candidate`.
  
 Example response candidate:
```
content {
  role: "model"
  parts {
    function_call {
      name: "Information"
      args {
        fields {
          key: "people"
          value {
            list_value {
              values {
                string_value: "Joe is 30, his mom is Martha"
              }
            }
          }
        }
      }
    }
  }
}
finish_reason: STOP
safety_ratings {
  category: HARM_CATEGORY_HARASSMENT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_HATE_SPEECH
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_SEXUALLY_EXPLICIT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_DANGEROUS_CONTENT
  probability: NEGLIGIBLE
}
```
 
error msg:
```
Traceback (most recent call last):
  File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in <module>
    print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"}))
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke
    input = step.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke
    return self.bound.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke
    self.generate_prompt(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt
    return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate
    raise e
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate
    self._generate_with_cache(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache
    return self._generate(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate
    generations = [
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in <listcomp>
    message=_parse_response_candidate(c),
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate
    function_call["arguments"] = json.dumps(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type RepeatedComposite is not JSON serializable
```
  

  **Twitter handle:**  @abehsu1992626
This commit is contained in:
hsuyuming 2024-02-08 12:48:25 -07:00 committed by GitHub
parent d77bb7b4e9
commit e22c4d4eb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 119 additions and 2 deletions

View File

@ -9,6 +9,7 @@ from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Union, cast from typing import Any, Dict, Iterator, List, Optional, Union, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import proto # type: ignore[import-untyped]
import requests import requests
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
first_part = response_candidate.content.parts[0] first_part = response_candidate.content.parts[0]
if first_part.function_call: if first_part.function_call:
function_call = {"name": first_part.function_call.name} function_call = {"name": first_part.function_call.name}
# dump to match other function calling llm for now # dump to match other function calling llm for now
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
"args"
]
function_call["arguments"] = json.dumps( function_call["arguments"] = json.dumps(
{k: first_part.function_call.args[k] for k in first_part.function_call.args} {k: function_call_args_dict[k] for k in function_call_args_dict}
) )
additional_kwargs["function_call"] = function_call additional_kwargs["function_call"] = function_call
return AIMessage(content=content, additional_kwargs=additional_kwargs) return AIMessage(content=content, additional_kwargs=additional_kwargs)

View File

@ -1,22 +1,35 @@
"""Test chat model integration.""" """Test chat model integration."""
import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from google.cloud.aiplatform_v1beta1.types import (
Content,
FunctionCall,
Part,
)
from google.cloud.aiplatform_v1beta1.types import (
content as gapic_content_types,
)
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
)
from langchain_google_vertexai.chat_models import ( from langchain_google_vertexai.chat_models import (
ChatVertexAI, ChatVertexAI,
_parse_chat_history, _parse_chat_history,
_parse_chat_history_gemini, _parse_chat_history_gemini,
_parse_examples, _parse_examples,
_parse_response_candidate,
) )
@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
message = HumanMessage(content=user_prompt) message = HumanMessage(content=user_prompt)
_ = model([message]) _ = model([message])
mock_start_chat.assert_called_once_with(history=[]) mock_start_chat.assert_called_once_with(history=[])
@pytest.mark.parametrize(
    "raw_candidate, expected",
    [
        # Flat args: a single string field.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"name": "Ben"},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"name": "Ben"},
            },
        ),
        # A list of strings as a field value.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"info": ["A", "B", "C"]},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"info": ["A", "B", "C"]},
            },
        ),
        # A list of nested structs with differing keys.
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={
                                    "people": [
                                        {"name": "Joe", "age": 30},
                                        {"name": "Martha"},
                                    ]
                                },
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {
                    "people": [
                        {"name": "Joe", "age": 30},
                        {"name": "Martha"},
                    ]
                },
            },
        ),
        # A list of lists (nested list values).
        (
            gapic_content_types.Candidate(
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            function_call=FunctionCall(
                                name="Information",
                                args={"info": [[1, 2, 3], [4, 5, 6]]},
                            ),
                        )
                    ],
                )
            ),
            {
                "name": "Information",
                "arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
            },
        ),
    ],
)
def test_parse_response_candidate(
    raw_candidate: gapic_content_types.Candidate,
    expected: Dict[str, Any],
) -> None:
    """Check that a function-call candidate is parsed into ``additional_kwargs``.

    Each case wraps a gapic ``Candidate`` holding a single ``FunctionCall``
    part, converts it with ``Candidate._from_gapic``, and passes it to
    ``_parse_response_candidate``. The resulting ``function_call`` entry must
    carry the expected name, and its ``arguments`` JSON string must decode to
    the expected structure.
    """
    response_candidate = Candidate._from_gapic(raw_candidate)
    result = _parse_response_candidate(response_candidate)
    function_call = result.additional_kwargs["function_call"]

    # The name is declared in every expected payload, so verify it too
    # rather than only the arguments.
    assert function_call["name"] == expected["name"]

    # Arguments are stored as a JSON string; decode before comparing.
    result_arguments = json.loads(function_call["arguments"])
    assert result_arguments == expected["arguments"]