langchain-google-vertexai: preserving grounding metadata (#16309)
Revival of https://github.com/langchain-ai/langchain/pull/14549 that closes https://github.com/langchain-ai/langchain/issues/14548.
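In practice this means the grounding fields Vertex AI returns (safety ratings, citation metadata) now survive into each generation's generation_info instead of being dropped. A minimal usage sketch, assuming a Gemini model; the model name and prompt are illustrative, not from the diff:

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-pro")
result = llm.generate([[HumanMessage(content="Who founded LangChain?")]])
# After this change, generation_info should carry the grounding fields, e.g.
# {"is_blocked": False, "safety_ratings": [...], "citation_metadata": ...}
print(result.generations[0][0].generation_info)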
This commit is contained in:
parent adc008407e, commit 0785432e7b
Utilities module:

@@ -1,4 +1,6 @@
 """Utilities to init Vertex AI."""
+
+import dataclasses
 from importlib import metadata
 from typing import Any, Callable, Dict, Optional, Union
 
@@ -10,7 +12,13 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from vertexai.preview.generative_models import Image  # type: ignore
+from vertexai.generative_models._generative_models import (  # type: ignore[import-untyped]
+    Candidate,
+)
+from vertexai.language_models import (  # type: ignore[import-untyped]
+    TextGenerationResponse,
+)
+from vertexai.preview.generative_models import Image  # type: ignore[import-untyped]
 
 
 def create_retry_decorator(
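The newly imported Candidate and TextGenerationResponse types feed the tightened get_generation_info signature in the next hunk.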
@@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool:
     return model_name is not None and "gemini" in model_name
 
 
-def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
-    try:
-        if is_gemini:
-            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
-            return {
-                "is_blocked": any(
-                    [rating.blocked for rating in candidate.safety_ratings]
-                ),
-                "safety_ratings": [
-                    {
-                        "category": rating.category.name,
-                        "probability_label": rating.probability.name,
-                    }
-                    for rating in candidate.safety_ratings
-                ],
-            }
-        else:
-            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
-            return {
-                "is_blocked": candidate.is_blocked,
-                "safety_attributes": candidate.safety_attributes,
-            }
-    except Exception:
-        return None
+def get_generation_info(
+    candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
+) -> Dict[str, Any]:
+    if is_gemini:
+        # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
+        return {
+            "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
+            "safety_ratings": [
+                {
+                    "category": rating.category.name,
+                    "probability_label": rating.probability.name,
+                }
+                for rating in candidate.safety_ratings
+            ],
+            "citation_metadata": candidate.citation_metadata,
+        }
+    # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
+    candidate_dc = dataclasses.asdict(candidate)
+    candidate_dc.pop("text")
+    return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
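For the non-Gemini (PaLM) branch, the code relies on TextGenerationResponse being a dataclass, so dataclasses.asdict() can capture every public field without hand-listing them. A minimal sketch of that branch, using a hypothetical stand-in dataclass:

import dataclasses


@dataclasses.dataclass
class FakeTextCandidate:
    """Hypothetical stand-in for TextGenerationResponse."""

    text: str
    is_blocked: bool
    safety_attributes: dict


candidate = FakeTextCandidate(text="Hi", is_blocked=False, safety_attributes={})
info = dataclasses.asdict(candidate)
info.pop("text")  # the generated text itself is not metadata
print({k: v for k, v in info.items() if not k.startswith("_")})
# -> {'is_blocked': False, 'safety_attributes': {}}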
Integration tests (ChatGoogleVertexAI):

@@ -1,5 +1,5 @@
 """Test ChatGoogleVertexAI chat model."""
-from typing import cast
+from typing import Optional, cast
 
 import pytest
 from langchain_core.messages import (
@@ -16,7 +16,7 @@ model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
 
 
 @pytest.mark.parametrize("model_name", model_names_to_test)
-def test_initialization(model_name: str) -> None:
+def test_initialization(model_name: Optional[str]) -> None:
     """Test chat model initialization."""
     if model_name:
         model = ChatVertexAI(model_name=model_name)
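The Optional[str] annotations in this and the following hunks simply match model_names_to_test, whose first entry is None (the case that exercises the default model).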
@@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None:
 
 
 @pytest.mark.parametrize("model_name", model_names_to_test)
-def test_vertexai_single_call(model_name: str) -> None:
+def test_vertexai_single_call(model_name: Optional[str]) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
     else:
@@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None:
 
 
 @pytest.mark.parametrize("model_name", model_names_to_test)
-def test_vertexai_single_call_with_history(model_name: str) -> None:
+def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
     else:
@@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
 
 
 @pytest.mark.parametrize("model_name", model_names_to_test)
-def test_chat_vertexai_system_message(model_name: str) -> None:
+def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
     if model_name:
         model = ChatVertexAI(
             model_name=model_name, convert_system_message_to_human=True
Unit tests (chat model integration):

@@ -1,5 +1,7 @@
 """Test chat model integration."""
-from typing import Any, Dict, Optional
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
     )
 
 
+@dataclass
+class StubTextChatResponse:
+    """Stub text-chat response from VertexAI for testing."""
+
+    text: str
+
+
 @pytest.mark.parametrize("stop", [None, "stop1"])
 def test_vertexai_args_passed(stop: Optional[str]) -> None:
     response_text = "Goodbye"
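The dataclass stub replaces Mock because get_generation_info now runs dataclasses.asdict() on non-Gemini candidates, and asdict() rejects anything that is not a dataclass instance:

import dataclasses
from unittest.mock import Mock

try:
    dataclasses.asdict(Mock(text="Goodbye"))
except TypeError as err:
    print(err)  # asdict() should be called on dataclass instances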
@@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
     # Mock the library to ensure the args are passed correctly
     with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
         mock_response = MagicMock()
-        mock_response.candidates = [Mock(text=response_text)]
+        mock_response.candidates = [StubTextChatResponse(text=response_text)]
         mock_chat = MagicMock()
         mock_send_message = MagicMock(return_value=mock_response)
         mock_chat.send_message = mock_send_message
@@ -136,7 +145,7 @@ def test_default_params_palm() -> None:
 
     with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
         mock_response = MagicMock()
-        mock_response.candidates = [Mock(text="Goodbye")]
+        mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
         mock_chat = MagicMock()
         mock_send_message = MagicMock(return_value=mock_response)
         mock_chat.send_message = mock_send_message
@@ -159,13 +168,28 @@ def test_default_params_palm() -> None:
     )
 
 
+@dataclass
+class StubGeminiResponse:
+    """Stub gemini response from VertexAI for testing."""
+
+    text: str
+    content: Any
+    citation_metadata: Any
+    safety_ratings: List[Any] = field(default_factory=list)
+
+
 def test_default_params_gemini() -> None:
     user_prompt = "Hello"
 
     with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
         mock_response = MagicMock()
-        content = Mock(parts=[Mock(function_call=None)])
-        mock_response.candidates = [Mock(text="Goodbye", content=content)]
+        mock_response.candidates = [
+            StubGeminiResponse(
+                text="Goodbye",
+                content=Mock(parts=[Mock(function_call=None)]),
+                citation_metadata=Mock(),
+            )
+        ]
         mock_chat = MagicMock()
         mock_send_message = MagicMock(return_value=mock_response)
         mock_chat.send_message = mock_send_message
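With StubGeminiResponse in place, the Gemini branch of get_generation_info can be exercised offline. A hedged sketch, assuming the helper is importable from the package's _utils module:

from langchain_google_vertexai._utils import get_generation_info

stub = StubGeminiResponse(
    text="Goodbye",
    content=None,
    citation_metadata=None,
    safety_ratings=[],  # no ratings, so any([]) evaluates to False
)
info = get_generation_info(stub, is_gemini=True)
assert info["is_blocked"] is False
assert info["citation_metadata"] is None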