mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
google-vertexai[patch]: fix _parse_response_candidate issue (#16647)
**Description:** enable _parse_response_candidate to support complex structure format. **Issue:** currently, if Gemini response complex args format, people will get "TypeError: Object of type RepeatedComposite is not JSON serializable" error from _parse_response_candidate. response candidate example ``` content { role: "model" parts { function_call { name: "Information" args { fields { key: "people" value { list_value { values { string_value: "Joe is 30, his mom is Martha" } } } } } } } } finish_reason: STOP safety_ratings { category: HARM_CATEGORY_HARASSMENT probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_HATE_SPEECH probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_SEXUALLY_EXPLICIT probability: NEGLIGIBLE } safety_ratings { category: HARM_CATEGORY_DANGEROUS_CONTENT probability: NEGLIGIBLE } ``` error msg: ``` Traceback (most recent call last): File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in <module> print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"})) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke input = step.invoke( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke return self.bound.invoke( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke self.generate_prompt( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate raise e File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate self._generate_with_cache( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache return self._generate( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate generations = [ File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in <listcomp> message=_parse_response_candidate(c), File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate function_call["arguments"] = json.dumps( File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps return _default_encoder.encode(obj) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode chunks = self.iterencode(o, _one_shot=True) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode return _iterencode(o, 0) File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default raise TypeError(f'Object of type {o.__class__.__name__} ' TypeError: Object of type RepeatedComposite is not JSON serializable ``` **Twitter handle:** @abehsu1992626
This commit is contained in:
parent
d77bb7b4e9
commit
e22c4d4eb0
@ -9,6 +9,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import proto # type: ignore[import-untyped]
|
||||
import requests
|
||||
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
|
||||
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
|
||||
@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
|
||||
first_part = response_candidate.content.parts[0]
|
||||
if first_part.function_call:
|
||||
function_call = {"name": first_part.function_call.name}
|
||||
|
||||
# dump to match other function calling llm for now
|
||||
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
|
||||
"args"
|
||||
]
|
||||
function_call["arguments"] = json.dumps(
|
||||
{k: first_part.function_call.args[k] for k in first_part.function_call.args}
|
||||
{k: function_call_args_dict[k] for k in function_call_args_dict}
|
||||
)
|
||||
additional_kwargs["function_call"] = function_call
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
@ -1,22 +1,35 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
Content,
|
||||
FunctionCall,
|
||||
Part,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
content as gapic_content_types,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
|
||||
from vertexai.preview.generative_models import ( # type: ignore
|
||||
Candidate,
|
||||
)
|
||||
|
||||
from langchain_google_vertexai.chat_models import (
|
||||
ChatVertexAI,
|
||||
_parse_chat_history,
|
||||
_parse_chat_history_gemini,
|
||||
_parse_examples,
|
||||
_parse_response_candidate,
|
||||
)
|
||||
|
||||
|
||||
@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
|
||||
message = HumanMessage(content=user_prompt)
|
||||
_ = model([message])
|
||||
mock_start_chat.assert_called_once_with(history=[])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_candidate, expected",
|
||||
[
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"name": "Ben"},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"name": "Ben"},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"info": ["A", "B", "C"]},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"info": ["A", "B", "C"]},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={
|
||||
"people": [
|
||||
{"name": "Joe", "age": 30},
|
||||
{"name": "Martha"},
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {
|
||||
"people": [
|
||||
{"name": "Joe", "age": 30},
|
||||
{"name": "Martha"},
|
||||
]
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"info": [[1, 2, 3], [4, 5, 6]]},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_response_candidate(raw_candidate, expected) -> None:
|
||||
response_candidate = Candidate._from_gapic(raw_candidate)
|
||||
result = _parse_response_candidate(response_candidate)
|
||||
result_arguments = json.loads(
|
||||
result.additional_kwargs["function_call"]["arguments"]
|
||||
)
|
||||
|
||||
assert result_arguments == expected["arguments"]
|
||||
|
Loading…
Reference in New Issue
Block a user