mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-24 05:50:18 +00:00
google-vertexai[patch]: fix _parse_response_candidate issue (#16647)
**Description:** enable _parse_response_candidate to support complex
structure format.
**Issue:**
currently, if Gemini response complex args format, people will get
"TypeError: Object of type RepeatedComposite is not JSON serializable"
error from _parse_response_candidate.
response candidate example
```
content {
role: "model"
parts {
function_call {
name: "Information"
args {
fields {
key: "people"
value {
list_value {
values {
string_value: "Joe is 30, his mom is Martha"
}
}
}
}
}
}
}
}
finish_reason: STOP
safety_ratings {
category: HARM_CATEGORY_HARASSMENT
probability: NEGLIGIBLE
}
safety_ratings {
category: HARM_CATEGORY_HATE_SPEECH
probability: NEGLIGIBLE
}
safety_ratings {
category: HARM_CATEGORY_SEXUALLY_EXPLICIT
probability: NEGLIGIBLE
}
safety_ratings {
category: HARM_CATEGORY_DANGEROUS_CONTENT
probability: NEGLIGIBLE
}
```
error msg:
```
Traceback (most recent call last):
File "/home/jupyter/user/abehsu/gemini_langchain_tools/example2.py", line 36, in <module>
print(tagging_chain.invoke({"input": "Joe is 30, his mom is Martha"}))
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 2053, in invoke
input = step.invoke(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/runnables/base.py", line 3887, in invoke
return self.bound.invoke(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 165, in invoke
self.generate_prompt(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 543, in generate_prompt
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 407, in generate
raise e
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 397, in generate
self._generate_with_cache(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py", line 576, in _generate_with_cache
return self._generate(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 406, in _generate
generations = [
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 408, in <listcomp>
message=_parse_response_candidate(c),
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/site-packages/langchain_google_vertexai/chat_models.py", line 280, in _parse_response_candidate
function_call["arguments"] = json.dumps(
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/__init__.py", line 231, in dumps
return _default_encoder.encode(obj)
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 199, in encode
chunks = self.iterencode(o, _one_shot=True)
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 257, in iterencode
return _iterencode(o, 0)
File "/opt/conda/envs/gemini_langchain_tools/lib/python3.10/json/encoder.py", line 179, in default
raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type RepeatedComposite is not JSON serializable
```
**Twitter handle:** @abehsu1992626
This commit is contained in:
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import proto # type: ignore[import-untyped]
|
||||
import requests
|
||||
from google.cloud.aiplatform_v1beta1.types.content import Part as GapicPart
|
||||
from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall
|
||||
@@ -278,10 +279,12 @@ def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
|
||||
first_part = response_candidate.content.parts[0]
|
||||
if first_part.function_call:
|
||||
function_call = {"name": first_part.function_call.name}
|
||||
|
||||
# dump to match other function calling llm for now
|
||||
function_call_args_dict = proto.Message.to_dict(first_part.function_call)[
|
||||
"args"
|
||||
]
|
||||
function_call["arguments"] = json.dumps(
|
||||
{k: first_part.function_call.args[k] for k in first_part.function_call.args}
|
||||
{k: function_call_args_dict[k] for k in function_call_args_dict}
|
||||
)
|
||||
additional_kwargs["function_call"] = function_call
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
@@ -1,22 +1,35 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
Content,
|
||||
FunctionCall,
|
||||
Part,
|
||||
)
|
||||
from google.cloud.aiplatform_v1beta1.types import (
|
||||
content as gapic_content_types,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from vertexai.language_models import ChatMessage, InputOutputTextPair # type: ignore
|
||||
from vertexai.preview.generative_models import ( # type: ignore
|
||||
Candidate,
|
||||
)
|
||||
|
||||
from langchain_google_vertexai.chat_models import (
|
||||
ChatVertexAI,
|
||||
_parse_chat_history,
|
||||
_parse_chat_history_gemini,
|
||||
_parse_examples,
|
||||
_parse_response_candidate,
|
||||
)
|
||||
|
||||
|
||||
@@ -202,3 +215,104 @@ def test_default_params_gemini() -> None:
|
||||
message = HumanMessage(content=user_prompt)
|
||||
_ = model([message])
|
||||
mock_start_chat.assert_called_once_with(history=[])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"raw_candidate, expected",
|
||||
[
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"name": "Ben"},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"name": "Ben"},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"info": ["A", "B", "C"]},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"info": ["A", "B", "C"]},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={
|
||||
"people": [
|
||||
{"name": "Joe", "age": 30},
|
||||
{"name": "Martha"},
|
||||
]
|
||||
},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {
|
||||
"people": [
|
||||
{"name": "Joe", "age": 30},
|
||||
{"name": "Martha"},
|
||||
]
|
||||
},
|
||||
},
|
||||
),
|
||||
(
|
||||
gapic_content_types.Candidate(
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
function_call=FunctionCall(
|
||||
name="Information",
|
||||
args={"info": [[1, 2, 3], [4, 5, 6]]},
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
{
|
||||
"name": "Information",
|
||||
"arguments": {"info": [[1, 2, 3], [4, 5, 6]]},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_response_candidate(raw_candidate, expected) -> None:
|
||||
response_candidate = Candidate._from_gapic(raw_candidate)
|
||||
result = _parse_response_candidate(response_candidate)
|
||||
result_arguments = json.loads(
|
||||
result.additional_kwargs["function_call"]["arguments"]
|
||||
)
|
||||
|
||||
assert result_arguments == expected["arguments"]
|
||||
|
||||
Reference in New Issue
Block a user