google-vertexai[patch]: fix _parse_response_candidate issue (#16647)
**Description:** Enable `_parse_response_candidate` to support complex structures in function-call arguments.

**Issue:** Currently, if Gemini returns function-call args in a complex format (repeated or nested values), `_parse_response_candidate` raises `TypeError: Object of type RepeatedComposite is not JSON serializable`.

Response candidate example:

```
content {
  role: "model"
  parts {
    function_call {
      name: "Information"
      args {
        fields {
          key: "people"
          value {
            list_value {
              values {
                string_value: "Joe is 30, his mom is Martha"
              }
            }
          }
        }
      }
    }
  }
}
finish_reason: STOP
safety_ratings {
  category: HARM_CATEGORY_HARASSMENT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_HATE_SPEECH
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_SEXUALLY_EXPLICIT
  probability: NEGLIGIBLE
}
safety_ratings {
  category: HARM_CATEGORY_DANGEROUS_CONTENT
  probability: NEGLIGIBLE
}
```

Error message:

```
Traceback (most recent call last):
  File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in <module>
    print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"}))
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke
    input = step.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke
    return self.bound.invoke(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke
    self.generate_prompt(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt
    return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate
    raise e
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate
    self._generate_with_cache(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache
    return self._generate(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate
    generations = [
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in <listcomp>
    message=_parse_response_candidate(c),
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate
    function_call["arguments"] = json.dumps(
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps
    return _default_encoder.encode(obj)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type RepeatedComposite is not JSON serializable
```

**Twitter handle:** @abehsu1992626
Parent: d77bb7b4e9
Commit: e22c4d4eb0
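For context, a minimal sketch of the failure and of the conversion the patch relies on. This is illustrative, not part of the commit; it assumes `google-cloud-aiplatform` (which bundles the `proto-plus` runtime) is installed, and the hand-built `FunctionCall` stands in for what Gemini actually returns:

```python
import json

import proto  # proto-plus, the runtime behind google-cloud-aiplatform types
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall

# Args holding a list are backed by a protobuf Struct, whose list values are
# RepeatedComposite objects that the standard json encoder cannot handle.
fc = FunctionCall(
    name="Information",
    args={"people": ["Joe is 30, his mom is Martha"]},
)

try:
    # The old code path read values straight off the proto wrapper.
    json.dumps({k: fc.args[k] for k in fc.args})
except TypeError as err:
    print(err)  # Object of type RepeatedComposite is not JSON serializable

# proto.Message.to_dict recursively converts protobuf containers into plain
# Python dicts and lists, which serialize cleanly.
args_dict = proto.Message.to_dict(fc)["args"]
print(json.dumps(args_dict))  # {"people": ["Joe is 30, his mom is Martha"]}
```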
Changes to `langchain_google_vertexai/chat_models.py`:

```diff
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, Iterator, List, Optional, Union, cast
 from urllib.parse import urlparse
 
+import proto  # type: ignore[import-untyped]
 import requests
 from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
 from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
@@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
     first_part = response_candidate.content.parts[0]
     if first_part.function_call:
         function_call = {"name": first_part.function_call.name}
 
         # dump to match other function calling llm for now
+        function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
+            "args"
+        ]
         function_call["arguments"] = json.dumps(
-            {k: first_part.function_call.args[k] for k in first_part.function_call.args}
+            {k: function_call_args_dict[k] for k in function_call_args_dict}
         )
         additional_kwargs["function_call"] = function_call
     return AIMessage(content=content, additional_kwargs=additional_kwargs)
```
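The reason `proto.Message.to_dict` fixes the issue is that it flattens arbitrarily nested `Struct` values into plain Python containers, which is exactly the 'complex structure format' named in the commit title. A small illustrative sketch under the same assumptions as above:

```python
import proto  # type: ignore[import-untyped]
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall

fc = FunctionCall(
    name="Information",
    args={"people": [{"name": "Joe", "age": 30}, {"name": "Martha"}]},
)

# Nested structs and lists come back as ordinary dicts, lists, and scalars.
args = proto.Message.to_dict(fc)["args"]
print(type(args["people"]))     # <class 'list'>
print(type(args["people"][0]))  # <class 'dict'>
```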
And in the chat-model unit tests:

```diff
@@ -1,22 +1,35 @@
 """Test chat model integration."""
 
+import json
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
+from google.cloud.aiplatform_v1beta1.types import (
+    Content,
+    FunctionCall,
+    Part,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    content as gapic_content_types,
+)
 from langchain_core.messages import (
     AIMessage,
     HumanMessage,
     SystemMessage,
 )
 from vertexai.language_models import ChatMessage, InputOutputTextPair  # type: ignore
+from vertexai.preview.generative_models import (  # type: ignore
+    Candidate,
+)
 
 from langchain_google_vertexai.chat_models import (
     ChatVertexAI,
     _parse_chat_history,
     _parse_chat_history_gemini,
     _parse_examples,
+    _parse_response_candidate,
 )
 
 
@@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
     message = HumanMessage(content=user_prompt)
     _ = model([message])
     mock_start_chat.assert_called_once_with(history=[])
+
+
+@pytest.mark.parametrize(
+    "raw_candidate, expected",
+    [
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"name": "Ben"},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"name": "Ben"},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": ["A", "B", "C"]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": ["A", "B", "C"]},
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={
+                                    "people": [
+                                        {"name": "Joe", "age": 30},
+                                        {"name": "Martha"},
+                                    ]
+                                },
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {
+                    "people": [
+                        {"name": "Joe", "age": 30},
+                        {"name": "Martha"},
+                    ]
+                },
+            },
+        ),
+        (
+            gapic_content_types.Candidate(
+                content=Content(
+                    role="model",
+                    parts=[
+                        Part(
+                            function_call=FunctionCall(
+                                name="Information",
+                                args={"info": [[1, 2, 3], [4, 5, 6]]},
+                            ),
+                        )
+                    ],
+                )
+            ),
+            {
+                "name": "Information",
+                "arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
+            },
+        ),
+    ],
+)
+def test_parse_response_candidate(raw_candidate, expected) -> None:
+    response_candidate = Candidate._from_gapic(raw_candidate)
+    result = _parse_response_candidate(response_candidate)
+    result_arguments = json.loads(
+        result.additional_kwargs["function_call"]["arguments"]
+    )
+
+    assert result_arguments == expected["arguments"]
```
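One subtlety behind these assertions, noted here as an assumption rather than something the commit states: protobuf `Struct` stores all numbers as doubles, so the converted args should come back with `30.0` rather than `30` for the age above. The `==` checks still pass because `30.0 == 30` in Python:

```python
import json

import proto  # type: ignore[import-untyped]
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall

fc = FunctionCall(
    name="Information",
    args={"people": [{"name": "Joe", "age": 30}]},
)
args = proto.Message.to_dict(fc)["args"]

# Struct numbers round-trip as floats, yet compare equal to the original ints.
print(args)  # {'people': [{'name': 'Joe', 'age': 30.0}]}
assert json.loads(json.dumps(args)) == {"people": [{"name": "Joe", "age": 30}]}
```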