mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
langchain-google-vertexai: perserving grounding metadata (#16309)
Revival of https://github.com/langchain-ai/langchain/pull/14549 that closes https://github.com/langchain-ai/langchain/issues/14548.
This commit is contained in:
parent
adc008407e
commit
0785432e7b
@ -1,4 +1,6 @@
|
||||
"""Utilities to init Vertex AI."""
|
||||
|
||||
import dataclasses
|
||||
from importlib import metadata
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
@ -10,7 +12,13 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import create_base_retry_decorator
|
||||
from vertexai.preview.generative_models import Image # type: ignore
|
||||
from vertexai.generative_models._generative_models import ( # type: ignore[import-untyped]
|
||||
Candidate,
|
||||
)
|
||||
from vertexai.language_models import ( # type: ignore[import-untyped]
|
||||
TextGenerationResponse,
|
||||
)
|
||||
from vertexai.preview.generative_models import Image # type: ignore[import-untyped]
|
||||
|
||||
|
||||
def create_retry_decorator(
|
||||
@ -88,27 +96,23 @@ def is_gemini_model(model_name: str) -> bool:
|
||||
return model_name is not None and "gemini" in model_name
|
||||
|
||||
|
||||
def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
if is_gemini:
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
|
||||
return {
|
||||
"is_blocked": any(
|
||||
[rating.blocked for rating in candidate.safety_ratings]
|
||||
),
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": rating.category.name,
|
||||
"probability_label": rating.probability.name,
|
||||
}
|
||||
for rating in candidate.safety_ratings
|
||||
],
|
||||
}
|
||||
else:
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
|
||||
return {
|
||||
"is_blocked": candidate.is_blocked,
|
||||
"safety_attributes": candidate.safety_attributes,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
def get_generation_info(
|
||||
candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
|
||||
) -> Dict[str, Any]:
|
||||
if is_gemini:
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
|
||||
return {
|
||||
"is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
|
||||
"safety_ratings": [
|
||||
{
|
||||
"category": rating.category.name,
|
||||
"probability_label": rating.probability.name,
|
||||
}
|
||||
for rating in candidate.safety_ratings
|
||||
],
|
||||
"citation_metadata": candidate.citation_metadata,
|
||||
}
|
||||
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
|
||||
candidate_dc = dataclasses.asdict(candidate)
|
||||
candidate_dc.pop("text")
|
||||
return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Test ChatGoogleVertexAI chat model."""
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
@ -16,7 +16,7 @@ model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_initialization(model_name: str) -> None:
|
||||
def test_initialization(model_name: Optional[str]) -> None:
|
||||
"""Test chat model initialization."""
|
||||
if model_name:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
@ -30,7 +30,7 @@ def test_initialization(model_name: str) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_vertexai_single_call(model_name: str) -> None:
|
||||
def test_vertexai_single_call(model_name: Optional[str]) -> None:
|
||||
if model_name:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
else:
|
||||
@ -164,7 +164,7 @@ def test_vertexai_single_call_with_examples() -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_vertexai_single_call_with_history(model_name: str) -> None:
|
||||
def test_vertexai_single_call_with_history(model_name: Optional[str]) -> None:
|
||||
if model_name:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
else:
|
||||
@ -203,7 +203,7 @@ def test_chat_vertexai_gemini_system_message_error(model_name: str) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", model_names_to_test)
|
||||
def test_chat_vertexai_system_message(model_name: str) -> None:
|
||||
def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
|
||||
if model_name:
|
||||
model = ChatVertexAI(
|
||||
model_name=model_name, convert_system_message_to_human=True
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Test chat model integration."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -45,6 +47,13 @@ def test_parse_examples_failes_wrong_sequence() -> None:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubTextChatResponse:
|
||||
"""Stub text-chat response from VertexAI for testing."""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stop", [None, "stop1"])
|
||||
def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
||||
response_text = "Goodbye"
|
||||
@ -59,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
||||
# Mock the library to ensure the args are passed correctly
|
||||
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [Mock(text=response_text)]
|
||||
mock_response.candidates = [StubTextChatResponse(text=response_text)]
|
||||
mock_chat = MagicMock()
|
||||
mock_send_message = MagicMock(return_value=mock_response)
|
||||
mock_chat.send_message = mock_send_message
|
||||
@ -136,7 +145,7 @@ def test_default_params_palm() -> None:
|
||||
|
||||
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [Mock(text="Goodbye")]
|
||||
mock_response.candidates = [StubTextChatResponse(text="Goodbye")]
|
||||
mock_chat = MagicMock()
|
||||
mock_send_message = MagicMock(return_value=mock_response)
|
||||
mock_chat.send_message = mock_send_message
|
||||
@ -159,13 +168,28 @@ def test_default_params_palm() -> None:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubGeminiResponse:
|
||||
"""Stub gemini response from VertexAI for testing."""
|
||||
|
||||
text: str
|
||||
content: Any
|
||||
citation_metadata: Any
|
||||
safety_ratings: List[Any] = field(default_factory=list)
|
||||
|
||||
|
||||
def test_default_params_gemini() -> None:
|
||||
user_prompt = "Hello"
|
||||
|
||||
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
|
||||
mock_response = MagicMock()
|
||||
content = Mock(parts=[Mock(function_call=None)])
|
||||
mock_response.candidates = [Mock(text="Goodbye", content=content)]
|
||||
mock_response.candidates = [
|
||||
StubGeminiResponse(
|
||||
text="Goodbye",
|
||||
content=Mock(parts=[Mock(function_call=None)]),
|
||||
citation_metadata=Mock(),
|
||||
)
|
||||
]
|
||||
mock_chat = MagicMock()
|
||||
mock_send_message = MagicMock(return_value=mock_response)
|
||||
mock_chat.send_message = mock_send_message
|
||||
|
Loading…
Reference in New Issue
Block a user