Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 15:43:54 +00:00)
google-vertexai[minor]: added safety_settings property to gemini wrapper (#15344)
**Description:** The Gemini model ships with rather restrictive default safety settings, and the current VertexAI class provides no way to override them. This PR therefore:

- adds a `safety_settings` property to VertexAI;
- fixes incorrect parsing of LLM output when the model responds with an appropriate "blocked" response;
- fixes incorrect parsing of LLM output when the Gemini API blocks the prompt itself as inappropriate;
- adds `safety_settings`-related tests.

I'm not very familiar with the langchain code base and guidelines, so any comments and/or suggestions are very welcome. A usage sketch of the new property follows below.

**Issue:** it will likely fix #14841

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
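A minimal sketch of how the new property is meant to be used, based on the docstring and tests in this diff (model name, prompt, and threshold choices are illustrative; the enum values are re-exported by this PR from the Vertex AI SDK):

```python
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI

# Relax or tighten individual harm categories for every generation
# made through this model instance.
safety_settings = {
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}

llm = VertexAI(model_name="gemini-pro", safety_settings=safety_settings)
output = llm.generate(["Explain how airbags deploy."])

# Each generation now also carries safety metadata.
print(output.generations[0][0].generation_info)
```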
This commit is contained in:
parent
ecd4f0a7ec
commit
6b9e3ed9e9
docs notebook: ChatVertexAI (filename not shown in this view)

@@ -35,7 +35,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 1,
 "metadata": {
 "tags": []
 },
@@ -44,10 +44,9 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"\n",
-"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
-"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
-"Note: you may need to restart the kernel to use updated packages.\n"
+"^C\n",
+"\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n",
+"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
 ]
 }
 ],
@@ -57,7 +56,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -67,7 +66,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -76,7 +75,7 @@
 "AIMessage(content=\" J'aime la programmation.\")"
 ]
 },
-"execution_count": 2,
+"execution_count": 8,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -101,7 +100,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -110,7 +109,7 @@
 "AIMessage(content=' プログラミングが大好きです')"
 ]
 },
-"execution_count": 3,
+"execution_count": 9,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -154,7 +153,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": null,
 "metadata": {
 "tags": []
 },
@@ -165,27 +164,51 @@
 "text": [
-" ```python\n",
 "def is_prime(n):\n",
-" if n <= 1:\n",
-" return False\n",
-" for i in range(2, n):\n",
-" if n % i == 0:\n",
-" return False\n",
-" return True\n",
+" \"\"\"\n",
+" Check if a number is prime.\n",
+"\n",
+" Args:\n",
+" n: The number to check.\n",
+"\n",
+" Returns:\n",
+" True if n is prime, False otherwise.\n",
+" \"\"\"\n",
+"\n",
+" # If n is 1, it is not prime.\n",
+" if n == 1:\n",
+" return False\n",
+"\n",
+" # Iterate over all numbers from 2 to the square root of n.\n",
+" for i in range(2, int(n ** 0.5) + 1):\n",
+" # If n is divisible by any number from 2 to its square root, it is not prime.\n",
+" if n % i == 0:\n",
+" return False\n",
+"\n",
+" # If n is divisible by no number from 2 to its square root, it is prime.\n",
+" return True\n",
+"\n",
 "\n",
 "def find_prime_numbers(n):\n",
-" prime_numbers = []\n",
-" for i in range(2, n + 1):\n",
-" if is_prime(i):\n",
-" prime_numbers.append(i)\n",
-" return prime_numbers\n",
+" \"\"\"\n",
+" Find all prime numbers up to a given number.\n",
 "\n",
-"print(find_prime_numbers(100))\n",
-"```\n",
+" Args:\n",
+" n: The upper bound for the prime numbers to find.\n",
 "\n",
-"Output:\n",
+" Returns:\n",
+" A list of all prime numbers up to n.\n",
+" \"\"\"\n",
 "\n",
-"```\n",
-"[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]\n",
+" # Create a list of all numbers from 2 to n.\n",
+" numbers = list(range(2, n + 1))\n",
+"\n",
+" # Iterate over the list of numbers and remove any that are not prime.\n",
+" for number in numbers:\n",
+" if not is_prime(number):\n",
+" numbers.remove(number)\n",
+"\n",
+" # Return the list of prime numbers.\n",
+" return numbers\n",
 "```\n"
 ]
 }
@@ -199,6 +222,102 @@
 "print(message.content)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Full generation info\n",
+"\n",
+"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just chat completions\n",
+"\n",
+"Note that the `generation_info` will be different depending if you're using a gemini model or not.\n",
+"\n",
+"### Gemini model\n",
+"\n",
+"`generation_info` will include:\n",
+"\n",
+"- `is_blocked`: whether generation was blocked or not\n",
+"- `safety_ratings`: safety ratings' categories and probability labels"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'is_blocked': False,\n",
+" 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n",
+" 'probability_label': 'NEGLIGIBLE'},\n",
+" {'category': 'HARM_CATEGORY_HATE_SPEECH',\n",
+" 'probability_label': 'NEGLIGIBLE'},\n",
+" {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n",
+" 'probability_label': 'NEGLIGIBLE'},\n",
+" {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n",
+" 'probability_label': 'NEGLIGIBLE'}]}\n"
+]
+}
+],
+"source": [
+"from pprint import pprint\n",
+"\n",
+"from langchain_core.messages import HumanMessage\n",
+"from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory\n",
+"\n",
+"human = \"Translate this sentence from English to French. I love programming.\"\n",
+"messages = [HumanMessage(content=human)]\n",
+"\n",
+"\n",
+"chat = ChatVertexAI(\n",
+" model_name=\"gemini-pro\",\n",
+" safety_settings={\n",
+" HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE\n",
+" },\n",
+")\n",
+"\n",
+"result = chat.generate([messages])\n",
+"pprint(result.generations[0][0].generation_info)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Non-gemini model\n",
+"\n",
+"`generation_info` will include:\n",
+"\n",
+"- `is_blocked`: whether generation was blocked or not\n",
+"- `safety_attributes`: a dictionary mapping safety attributes to their scores"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 13,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"{'is_blocked': False,\n",
+" 'safety_attributes': {'Derogatory': 0.1,\n",
+" 'Finance': 0.3,\n",
+" 'Insult': 0.1,\n",
+" 'Sexual': 0.1}}\n"
+]
+}
+],
+"source": [
+"chat = ChatVertexAI() # default is `chat-bison`\n",
+"\n",
+"result = chat.generate([messages])\n",
+"pprint(result.generations[0][0].generation_info)"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -210,7 +329,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -224,7 +343,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -268,7 +387,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
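The markdown cells added above describe `is_blocked` and `safety_ratings`; for completeness, a hedged sketch of how a caller might guard against the blocked case (the prompt is a placeholder; a blocked candidate carries no usable text, only safety metadata):

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI

chat = ChatVertexAI(model_name="gemini-pro")
result = chat.generate([[HumanMessage(content="...")]])

generation = result.generations[0][0]
info = generation.generation_info or {}
if info.get("is_blocked"):
    # Generation was stopped by a safety filter; inspect the ratings instead.
    print("Blocked by safety filters:", info.get("safety_ratings"))
else:
    print(generation.message.content)
```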
langchain_google_vertexai/__init__.py

@@ -1,5 +1,13 @@
+from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_vertexai.chat_models import ChatVertexAI
 from langchain_google_vertexai.embeddings import VertexAIEmbeddings
 from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden

-__all__ = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
+__all__ = [
+    "ChatVertexAI",
+    "VertexAIEmbeddings",
+    "VertexAI",
+    "VertexAIModelGarden",
+    "HarmBlockThreshold",
+    "HarmCategory",
+]
langchain_google_vertexai/_enums.py (new file)

@@ -0,0 +1,6 @@
+from vertexai.preview.generative_models import (  # type: ignore
+    HarmBlockThreshold,
+    HarmCategory,
+)
+
+__all__ = ["HarmBlockThreshold", "HarmCategory"]
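Since `_enums.py` only re-exports the SDK's enums, settings built from either import path are interchangeable; a small sketch (assuming `google-cloud-aiplatform` is installed):

```python
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
from vertexai.preview.generative_models import HarmBlockThreshold as SdkThreshold
from vertexai.preview.generative_models import HarmCategory as SdkCategory

# The re-exported names are the SDK enums themselves, so dictionaries built
# with either import can be passed to safety_settings.
assert HarmCategory is SdkCategory
assert HarmBlockThreshold is SdkThreshold
```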
langchain_google_vertexai/_utils.py

@@ -1,6 +1,6 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union

 import google.api_core
 from google.api_core.gapic_v1.client_info import ClientInfo
@@ -86,3 +86,29 @@ def is_codey_model(model_name: str) -> bool:
 def is_gemini_model(model_name: str) -> bool:
     """Returns True if the model name is a Gemini model."""
     return model_name is not None and "gemini" in model_name
+
+
+def get_generation_info(candidate: Any, is_gemini: bool) -> Optional[Dict[str, Any]]:
+    try:
+        if is_gemini:
+            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
+            return {
+                "is_blocked": any(
+                    [rating.blocked for rating in candidate.safety_ratings]
+                ),
+                "safety_ratings": [
+                    {
+                        "category": rating.category.name,
+                        "probability_label": rating.probability.name,
+                    }
+                    for rating in candidate.safety_ratings
+                ],
+            }
+        else:
+            # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
+            return {
+                "is_blocked": candidate.is_blocked,
+                "safety_attributes": candidate.safety_attributes,
+            }
+    except Exception:
+        return None
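To make the helper's contract concrete, a toy stand-in candidate (only the attributes the Gemini branch reads) run through `get_generation_info`; this assumes the package and its Vertex AI dependencies are installed, and the stand-in classes are purely illustrative:

```python
from enum import Enum
from types import SimpleNamespace

from langchain_google_vertexai._utils import get_generation_info


class Cat(Enum):
    HARM_CATEGORY_HARASSMENT = 1


class Prob(Enum):
    NEGLIGIBLE = 1


# Mimics an SDK candidate: a list of ratings with blocked/category/probability.
candidate = SimpleNamespace(
    safety_ratings=[
        SimpleNamespace(
            blocked=False,
            category=Cat.HARM_CATEGORY_HARASSMENT,
            probability=Prob.NEGLIGIBLE,
        )
    ]
)

print(get_generation_info(candidate, is_gemini=True))
# {'is_blocked': False,
#  'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',
#                      'probability_label': 'NEGLIGIBLE'}]}
```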
langchain_google_vertexai/chat_models.py

@@ -47,6 +47,9 @@ from vertexai.preview.generative_models import (  # type: ignore
 )

 from langchain_google_vertexai._utils import (
+    get_generation_info,
+    is_codey_model,
+    is_gemini_model,
     load_image_from_gcs,
 )
 from langchain_google_vertexai.functions_utils import (
@@ -54,8 +57,6 @@ from langchain_google_vertexai.functions_utils import (
 )
 from langchain_google_vertexai.llms import (
     _VertexAICommon,
-    is_codey_model,
-    is_gemini_model,
 )

 logger = logging.getLogger(__name__)
@@ -271,9 +272,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
         is_gemini = is_gemini_model(values["model_name"])
+        safety_settings = values["safety_settings"]
+
+        if safety_settings and not is_gemini:
+            raise ValueError("Safety settings are only supported for Gemini models")
+
         cls._init_vertexai(values)
         if is_gemini:
-            values["client"] = GenerativeModel(model_name=values["model_name"])
+            values["client"] = GenerativeModel(
+                model_name=values["model_name"], safety_settings=safety_settings
+            )
         else:
             if is_codey_model(values["model_name"]):
                 model_cls = CodeChatModel
@@ -306,6 +314,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             ValueError: if the last message in the list is not from human.
         """
         should_stream = stream if stream is not None else self.streaming
+        safety_settings = kwargs.pop("safety_settings", None)
         if should_stream:
             stream_iter = self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
@@ -325,9 +334,17 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             # set param to `functions` until core tool/function calling implemented
             raw_tools = params.pop("functions") if "functions" in params else None
             tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
-            response = chat.send_message(message, generation_config=params, tools=tools)
+            response = chat.send_message(
+                message,
+                generation_config=params,
+                tools=tools,
+                safety_settings=safety_settings,
+            )
             generations = [
-                ChatGeneration(message=_parse_response_candidate(c))
+                ChatGeneration(
+                    message=_parse_response_candidate(c),
+                    generation_info=get_generation_info(c, self._is_gemini_model),
+                )
                 for c in response.candidates
             ]
         else:
@@ -339,7 +356,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             chat = self._start_chat(history, **params)
             response = chat.send_message(question.content, **msg_params)
             generations = [
-                ChatGeneration(message=AIMessage(content=r.text))
+                ChatGeneration(
+                    message=AIMessage(content=r.text),
+                    generation_info=get_generation_info(r, self._is_gemini_model),
+                )
                 for r in response.candidates
             ]
         return ChatResult(generations=generations)
@@ -370,6 +390,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             logger.warning("ChatVertexAI does not currently support async streaming.")

         params = self._prepare_params(stop=stop, **kwargs)
+        safety_settings = kwargs.pop("safety_settings", None)
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
@@ -382,22 +403,31 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             raw_tools = params.pop("functions") if "functions" in params else None
             tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
             response = await chat.send_message_async(
-                message, generation_config=params, tools=tools
+                message,
+                generation_config=params,
+                tools=tools,
+                safety_settings=safety_settings,
             )
             generations = [
-                ChatGeneration(message=_parse_response_candidate(c))
+                ChatGeneration(
+                    message=_parse_response_candidate(c),
+                    generation_info=get_generation_info(c, self._is_gemini_model),
+                )
                 for c in response.candidates
             ]
         else:
             question = _get_question(messages)
             history = _parse_chat_history(messages[:-1])
-            examples = kwargs.get("examples", None)
+            examples = kwargs.get("examples", None) or self.examples
             if examples:
                 params["examples"] = _parse_examples(examples)
             chat = self._start_chat(history, **params)
             response = await chat.send_message_async(question.content, **msg_params)
             generations = [
-                ChatGeneration(message=AIMessage(content=r.text))
+                ChatGeneration(
+                    message=AIMessage(content=r.text),
+                    generation_info=get_generation_info(r, self._is_gemini_model),
+                )
                 for r in response.candidates
             ]
         return ChatResult(generations=generations)
@@ -441,7 +471,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         for response in responses:
             if run_manager:
                 run_manager.on_llm_new_token(response.text)
-            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(content=response.text),
+                generation_info=get_generation_info(response, self._is_gemini_model),
+            )

     def _start_chat(
         self, history: _ChatHistory, **kwargs: Any
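Because `_generate` and `_agenerate` now pop `safety_settings` from kwargs and forward it to `send_message`, thresholds can also be overridden per call rather than per instance; a hedged sketch (message content and threshold choice are illustrative):

```python
from langchain_core.messages import HumanMessage
from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory

chat = ChatVertexAI(model_name="gemini-pro")

# Override safety thresholds for this call only; other calls keep the
# instance-level (or default) settings.
result = chat.generate(
    [[HumanMessage(content="Summarize the battle scenes in War and Peace.")]],
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH
    },
)
print(result.generations[0][0].generation_info)
```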
langchain_google_vertexai/llms.py

@@ -26,7 +26,10 @@ from vertexai.language_models import (  # type: ignore
 from vertexai.language_models._language_models import (  # type: ignore
     TextGenerationResponse,
 )
-from vertexai.preview.generative_models import GenerativeModel, Image  # type: ignore
+from vertexai.preview.generative_models import (  # type: ignore
+    GenerativeModel,
+    Image,
+)
 from vertexai.preview.language_models import (  # type: ignore
     CodeGenerationModel as PreviewCodeGenerationModel,
 )
@@ -34,9 +37,11 @@ from vertexai.preview.language_models import (
     TextGenerationModel as PreviewTextGenerationModel,
 )

+from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_vertexai._utils import (
     create_retry_decorator,
     get_client_info,
+    get_generation_info,
     is_codey_model,
     is_gemini_model,
 )
@@ -66,7 +71,10 @@ def _completion_with_retry(
     ) -> Any:
         if is_gemini:
             return llm.client.generate_content(
-                prompt, stream=stream, generation_config=kwargs
+                prompt,
+                stream=stream,
+                safety_settings=kwargs.pop("safety_settings", None),
+                generation_config=kwargs,
             )
         else:
             if stream:
@@ -94,7 +102,9 @@ async def _acompletion_with_retry(
     ) -> Any:
         if is_gemini:
             return await llm.client.generate_content_async(
-                prompt, generation_config=kwargs
+                prompt,
+                generation_config=kwargs,
+                safety_settings=kwargs.pop("safety_settings", None),
             )
         return await llm.client.predict_async(prompt, **kwargs)

@@ -141,6 +151,21 @@ class _VertexAICommon(_VertexAIBase):
     """How many completions to generate for each prompt."""
     streaming: bool = False
     """Whether to stream the results or not."""
+    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
+    """The default safety settings to use for all generations.
+
+    For example:
+
+        from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
+
+        safety_settings = {
+            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+        }
+    """  # noqa: E501

     @property
     def _llm_type(self) -> str:
@@ -237,9 +262,13 @@ class VertexAI(_VertexAICommon, BaseLLM):
         """Validate that the python package exists in environment."""
         tuned_model_name = values.get("tuned_model_name")
         model_name = values["model_name"]
+        safety_settings = values["safety_settings"]
         is_gemini = is_gemini_model(values["model_name"])
         cls._init_vertexai(values)

+        if safety_settings and (not is_gemini or tuned_model_name):
+            raise ValueError("Safety settings are only supported for Gemini models")
+
         if is_codey_model(model_name):
             model_cls = CodeGenerationModel
             preview_model_cls = PreviewCodeGenerationModel
@@ -257,8 +286,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
             )
         else:
             if is_gemini:
-                values["client"] = model_cls(model_name=model_name)
-                values["client_preview"] = preview_model_cls(model_name=model_name)
+                values["client"] = model_cls(
+                    model_name=model_name, safety_settings=safety_settings
+                )
+                values["client_preview"] = preview_model_cls(
+                    model_name=model_name, safety_settings=safety_settings
+                )
             else:
                 values["client"] = model_cls.from_pretrained(model_name)
                 values["client_preview"] = preview_model_cls.from_pretrained(model_name)
@@ -285,14 +318,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
         self, response: TextGenerationResponse
     ) -> GenerationChunk:
         """Converts a stream response to a generation chunk."""
-        try:
-            generation_info = {
-                "is_blocked": response.is_blocked,
-                "safety_attributes": response.safety_attributes,
-            }
-        except Exception:
-            generation_info = None
-        return GenerationChunk(text=response.text, generation_info=generation_info)
+        generation_info = get_generation_info(response, self._is_gemini_model)
+
+        return GenerationChunk(
+            text=response.text
+            if hasattr(response, "text")
+            else "",  # might not exist if blocked
+            generation_info=generation_info,
+        )

     def _generate(
         self,
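One behavioral consequence of the llms.py changes, exercised by the new tests below: when the Gemini API blocks the prompt itself, `generate` now returns an empty candidate list instead of failing while parsing the response. A hedged consumer-side sketch (the `"..."` placeholder stands for a prompt that trips a safety filter):

```python
from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="gemini-pro")
output = llm.generate(["..."])

if not output.generations[0]:
    # The prompt was blocked; no candidates were produced.
    print("Prompt blocked by Vertex AI safety filters.")
else:
    generation = output.generations[0][0]
    print(generation.text, generation.generation_info)
```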
libs/partners/google-vertexai/poetry.lock (generated, 94 lines changed)

@@ -504,13 +504,13 @@ uritemplate = ">=3.0.1,<5"

 [[package]]
 name = "google-auth"
-version = "2.26.1"
+version = "2.26.2"
 description = "Google Authentication Library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-auth-2.26.1.tar.gz", hash = "sha256:54385acca5c0fbdda510cd8585ba6f3fcb06eeecf8a6ecca39d3ee148b092590"},
-    {file = "google_auth-2.26.1-py2.py3-none-any.whl", hash = "sha256:2c8b55e3e564f298122a02ab7b97458ccfcc5617840beb5d0ac757ada92c9780"},
+    {file = "google-auth-2.26.2.tar.gz", hash = "sha256:97327dbbf58cccb58fc5a1712bba403ae76668e64814eb30f7316f7e27126b81"},
+    {file = "google_auth-2.26.2-py2.py3-none-any.whl", hash = "sha256:3f445c8ce9b61ed6459aad86d8ccdba4a9afed841b2d1451a11ef4db08957424"},
 ]

 [package.dependencies]
@@ -582,13 +582,13 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"]

 [[package]]
 name = "google-cloud-bigquery"
-version = "3.14.1"
+version = "3.16.0"
 description = "Google BigQuery API client library"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "google-cloud-bigquery-3.14.1.tar.gz", hash = "sha256:aa15bd86f79ea76824c7d710f5ae532323c4b3ba01ef4abff42d4ee7a2e9b142"},
-    {file = "google_cloud_bigquery-3.14.1-py2.py3-none-any.whl", hash = "sha256:a8ded18455da71508db222b7c06197bc12b6dbc6ed5b0b64e7007b76d7016957"},
+    {file = "google-cloud-bigquery-3.16.0.tar.gz", hash = "sha256:1d6abf4b1d740df17cb43a078789872af8059a0b1dd999f32ea69ebc6f7ba7ef"},
+    {file = "google_cloud_bigquery-3.16.0-py2.py3-none-any.whl", hash = "sha256:8bac7754f92bf87ee81f38deabb7554d82bb9591fbe06a5c82f33e46e5a482f9"},
 ]

 [package.dependencies]
@@ -1110,13 +1110,13 @@ url = "../../core"

 [[package]]
 name = "langsmith"
-version = "0.0.77"
+version = "0.0.81"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "langsmith-0.0.77-py3-none-any.whl", hash = "sha256:750c0aa9177240c64e131d831e009ed08dd59038f7cabbd0bbcf62ccb7c8dcac"},
-    {file = "langsmith-0.0.77.tar.gz", hash = "sha256:c4c8d3a96ad8671a41064f3ccc673e2e22a4153e823b19f915c9c9b8a4f33a2c"},
+    {file = "langsmith-0.0.81-py3-none-any.whl", hash = "sha256:eb816ad456776ec4c6005ddce8a4c315a1a582ed4d079979888e9f8a1db209b3"},
+    {file = "langsmith-0.0.81.tar.gz", hash = "sha256:5838e5a4bb1939e9794eb3f802f7c390247a847bd603e31442be5be00068e504"},
 ]

 [package.dependencies]
@@ -1410,22 +1410,22 @@ testing = ["google-api-core[grpc] (>=1.31.5)"]

 [[package]]
 name = "protobuf"
-version = "4.25.1"
+version = "4.25.2"
 description = ""
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "protobuf-4.25.1-cp310-abi3-win32.whl", hash = "sha256:193f50a6ab78a970c9b4f148e7c750cfde64f59815e86f686c22e26b4fe01ce7"},
-    {file = "protobuf-4.25.1-cp310-abi3-win_amd64.whl", hash = "sha256:3497c1af9f2526962f09329fd61a36566305e6c72da2590ae0d7d1322818843b"},
-    {file = "protobuf-4.25.1-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:0bf384e75b92c42830c0a679b0cd4d6e2b36ae0cf3dbb1e1dfdda48a244f4bcd"},
-    {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:0f881b589ff449bf0b931a711926e9ddaad3b35089cc039ce1af50b21a4ae8cb"},
-    {file = "protobuf-4.25.1-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:ca37bf6a6d0046272c152eea90d2e4ef34593aaa32e8873fc14c16440f22d4b7"},
-    {file = "protobuf-4.25.1-cp38-cp38-win32.whl", hash = "sha256:abc0525ae2689a8000837729eef7883b9391cd6aa7950249dcf5a4ede230d5dd"},
-    {file = "protobuf-4.25.1-cp38-cp38-win_amd64.whl", hash = "sha256:1484f9e692091450e7edf418c939e15bfc8fc68856e36ce399aed6889dae8bb0"},
-    {file = "protobuf-4.25.1-cp39-cp39-win32.whl", hash = "sha256:8bdbeaddaac52d15c6dce38c71b03038ef7772b977847eb6d374fc86636fa510"},
-    {file = "protobuf-4.25.1-cp39-cp39-win_amd64.whl", hash = "sha256:becc576b7e6b553d22cbdf418686ee4daa443d7217999125c045ad56322dda10"},
-    {file = "protobuf-4.25.1-py3-none-any.whl", hash = "sha256:a19731d5e83ae4737bb2a089605e636077ac001d18781b3cf489b9546c7c80d6"},
-    {file = "protobuf-4.25.1.tar.gz", hash = "sha256:57d65074b4f5baa4ab5da1605c02be90ac20c8b40fb137d6a8df9f416b0d0ce2"},
+    {file = "protobuf-4.25.2-cp310-abi3-win32.whl", hash = "sha256:b50c949608682b12efb0b2717f53256f03636af5f60ac0c1d900df6213910fd6"},
+    {file = "protobuf-4.25.2-cp310-abi3-win_amd64.whl", hash = "sha256:8f62574857ee1de9f770baf04dde4165e30b15ad97ba03ceac65f760ff018ac9"},
+    {file = "protobuf-4.25.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:2db9f8fa64fbdcdc93767d3cf81e0f2aef176284071507e3ede160811502fd3d"},
+    {file = "protobuf-4.25.2-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:10894a2885b7175d3984f2be8d9850712c57d5e7587a2410720af8be56cdaf62"},
+    {file = "protobuf-4.25.2-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fc381d1dd0516343f1440019cedf08a7405f791cd49eef4ae1ea06520bc1c020"},
+    {file = "protobuf-4.25.2-cp38-cp38-win32.whl", hash = "sha256:33a1aeef4b1927431d1be780e87b641e322b88d654203a9e9d93f218ee359e61"},
+    {file = "protobuf-4.25.2-cp38-cp38-win_amd64.whl", hash = "sha256:47f3de503fe7c1245f6f03bea7e8d3ec11c6c4a2ea9ef910e3221c8a15516d62"},
+    {file = "protobuf-4.25.2-cp39-cp39-win32.whl", hash = "sha256:5e5c933b4c30a988b52e0b7c02641760a5ba046edc5e43d3b94a74c9fc57c1b3"},
+    {file = "protobuf-4.25.2-cp39-cp39-win_amd64.whl", hash = "sha256:d66a769b8d687df9024f2985d5137a337f957a0916cf5464d1513eee96a63ff0"},
+    {file = "protobuf-4.25.2-py3-none-any.whl", hash = "sha256:a8b7a98d4ce823303145bf3c1a8bdb0f2f4642a414b196f04ad9853ed0c8f830"},
+    {file = "protobuf-4.25.2.tar.gz", hash = "sha256:fe599e175cb347efc8ee524bcd4b902d11f7262c0e569ececcb89995c15f0a5e"},
 ]

 [[package]]
@@ -1775,28 +1775,28 @@ pyasn1 = ">=0.1.3"

 [[package]]
 name = "ruff"
-version = "0.1.11"
+version = "0.1.13"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruff-0.1.11-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:a7f772696b4cdc0a3b2e527fc3c7ccc41cdcb98f5c80fdd4f2b8c50eb1458196"},
-    {file = "ruff-0.1.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:934832f6ed9b34a7d5feea58972635c2039c7a3b434fe5ba2ce015064cb6e955"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea0d3e950e394c4b332bcdd112aa566010a9f9c95814844a7468325290aabfd9"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bd4025b9c5b429a48280785a2b71d479798a69f5c2919e7d274c5f4b32c3607"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1ad00662305dcb1e987f5ec214d31f7d6a062cae3e74c1cbccef15afd96611d"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4b077ce83f47dd6bea1991af08b140e8b8339f0ba8cb9b7a484c30ebab18a23f"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a88efecec23c37b11076fe676e15c6cdb1271a38f2b415e381e87fe4517f18"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b25093dad3b055667730a9b491129c42d45e11cdb7043b702e97125bcec48a1"},
-    {file = "ruff-0.1.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:231d8fb11b2cc7c0366a326a66dafc6ad449d7fcdbc268497ee47e1334f66f77"},
-    {file = "ruff-0.1.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:09c415716884950080921dd6237767e52e227e397e2008e2bed410117679975b"},
-    {file = "ruff-0.1.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0f58948c6d212a6b8d41cd59e349751018797ce1727f961c2fa755ad6208ba45"},
-    {file = "ruff-0.1.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:190a566c8f766c37074d99640cd9ca3da11d8deae2deae7c9505e68a4a30f740"},
-    {file = "ruff-0.1.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6464289bd67b2344d2a5d9158d5eb81025258f169e69a46b741b396ffb0cda95"},
-    {file = "ruff-0.1.11-py3-none-win32.whl", hash = "sha256:9b8f397902f92bc2e70fb6bebfa2139008dc72ae5177e66c383fa5426cb0bf2c"},
-    {file = "ruff-0.1.11-py3-none-win_amd64.whl", hash = "sha256:eb85ee287b11f901037a6683b2374bb0ec82928c5cbc984f575d0437979c521a"},
-    {file = "ruff-0.1.11-py3-none-win_arm64.whl", hash = "sha256:97ce4d752f964ba559c7023a86e5f8e97f026d511e48013987623915431c7ea9"},
-    {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"},
+    {file = "ruff-0.1.13-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:e3fd36e0d48aeac672aa850045e784673449ce619afc12823ea7868fcc41d8ba"},
+    {file = "ruff-0.1.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9fb6b3b86450d4ec6a6732f9f60c4406061b6851c4b29f944f8c9d91c3611c7a"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b13ba5d7156daaf3fd08b6b993360a96060500aca7e307d95ecbc5bb47a69296"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9ebb40442f7b531e136d334ef0851412410061e65d61ca8ce90d894a094feb22"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226b517f42d59a543d6383cfe03cccf0091e3e0ed1b856c6824be03d2a75d3b6"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5f0312ba1061e9b8c724e9a702d3c8621e3c6e6c2c9bd862550ab2951ac75c16"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2f59bcf5217c661254bd6bc42d65a6fd1a8b80c48763cb5c2293295babd945dd"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e6894b00495e00c27b6ba61af1fc666f17de6140345e5ef27dd6e08fb987259d"},
+    {file = "ruff-0.1.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1600942485c6e66119da294c6294856b5c86fd6df591ce293e4a4cc8e72989"},
+    {file = "ruff-0.1.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ee3febce7863e231a467f90e681d3d89210b900d49ce88723ce052c8761be8c7"},
+    {file = "ruff-0.1.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dcaab50e278ff497ee4d1fe69b29ca0a9a47cd954bb17963628fa417933c6eb1"},
+    {file = "ruff-0.1.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f57de973de4edef3ad3044d6a50c02ad9fc2dff0d88587f25f1a48e3f72edf5e"},
+    {file = "ruff-0.1.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7a36fa90eb12208272a858475ec43ac811ac37e91ef868759770b71bdabe27b6"},
+    {file = "ruff-0.1.13-py3-none-win32.whl", hash = "sha256:a623349a505ff768dad6bd57087e2461be8db58305ebd5577bd0e98631f9ae69"},
+    {file = "ruff-0.1.13-py3-none-win_amd64.whl", hash = "sha256:f988746e3c3982bea7f824c8fa318ce7f538c4dfefec99cd09c8770bd33e6539"},
+    {file = "ruff-0.1.13-py3-none-win_arm64.whl", hash = "sha256:6bbbc3042075871ec17f28864808540a26f0f79a4478c357d3e3d2284e832998"},
+    {file = "ruff-0.1.13.tar.gz", hash = "sha256:e261f1baed6291f434ffb1d5c6bd8051d1c2a26958072d38dfbec39b3dda7352"},
 ]

 [[package]]
@@ -2033,24 +2033,24 @@ files = [

 [[package]]
 name = "types-protobuf"
-version = "4.24.0.4"
+version = "4.24.0.20240106"
 description = "Typing stubs for protobuf"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "types-protobuf-4.24.0.4.tar.gz", hash = "sha256:57ab42cb171dfdba2c74bb5b50c250478538cc3c5ed95b8b368929ad0c9f90a5"},
-    {file = "types_protobuf-4.24.0.4-py3-none-any.whl", hash = "sha256:131ab7d0cbc9e444bc89c994141327dcce7bcaeded72b1acb72a94827eb9c7af"},
+    {file = "types-protobuf-4.24.0.20240106.tar.gz", hash = "sha256:024f034f3b5e2bb2bbff55ebc4d591ed0d2280d90faceedcb148b9e714a3f3ee"},
+    {file = "types_protobuf-4.24.0.20240106-py3-none-any.whl", hash = "sha256:0612ef3156bd80567460a15ac7c109b313f6022f1fee04b4d922ab2789baab79"},
 ]

 [[package]]
 name = "types-requests"
-version = "2.31.0.20231231"
+version = "2.31.0.20240106"
 description = "Typing stubs for requests"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "types-requests-2.31.0.20231231.tar.gz", hash = "sha256:0f8c0c9764773384122813548d9eea92a5c4e1f33ed54556b508968ec5065cee"},
-    {file = "types_requests-2.31.0.20231231-py3-none-any.whl", hash = "sha256:2e2230c7bc8dd63fa3153c1c0ae335f8a368447f0582fc332f17d54f88e69027"},
+    {file = "types-requests-2.31.0.20240106.tar.gz", hash = "sha256:0e1c731c17f33618ec58e022b614a1a2ecc25f7dc86800b36ef341380402c612"},
+    {file = "types_requests-2.31.0.20240106-py3-none-any.whl", hash = "sha256:da997b3b6a72cc08d09f4dba9802fdbabc89104b35fe24ee588e674037689354"},
 ]

 [package.dependencies]
integration tests: chat models

@@ -1,4 +1,6 @@
 """Test ChatGoogleVertexAI chat model."""
+from typing import cast
+
 import pytest
 from langchain_core.messages import (
     AIMessage,
@@ -6,7 +8,7 @@ from langchain_core.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import LLMResult
+from langchain_core.outputs import ChatGeneration, LLMResult

 from langchain_google_vertexai.chat_models import ChatVertexAI

@@ -60,7 +62,13 @@ async def test_vertexai_agenerate(model_name: str) -> None:
     assert isinstance(response.generations[0][0].message, AIMessage)  # type: ignore

     sync_response = model.generate([[message]])
-    assert response.generations[0][0] == sync_response.generations[0][0]
+    sync_generation = cast(ChatGeneration, sync_response.generations[0][0])
+    async_generation = cast(ChatGeneration, response.generations[0][0])
+
+    # assert some properties to make debugging easier
+    assert sync_generation.message.content == async_generation.message.content
+    assert sync_generation.generation_info == async_generation.generation_info
+    assert sync_generation == async_generation


 @pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
integration tests: llms

@@ -42,6 +42,7 @@ def test_vertex_call(model_name: str) -> None:
     assert isinstance(output, str)


+@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
 def test_vertex_generate() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
     output = llm.generate(["Say foo:"])
@@ -50,6 +51,7 @@ def test_vertex_generate() -> None:
     assert len(output.generations[0]) == 2


+@pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
 def test_vertex_generate_code() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="code-bison@001")
     output = llm.generate(["generate a python method that says foo:"])
@@ -87,6 +89,7 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text


+@pytest.mark.skip("CI testing not set up")
 @pytest.mark.parametrize(
     "endpoint_os_variable_name,result_arg",
     [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
@@ -115,6 +118,7 @@ def test_model_garden(
     assert llm._llm_type == "vertexai_model_garden"


+@pytest.mark.skip("CI testing not set up")
 @pytest.mark.parametrize(
     "endpoint_os_variable_name,result_arg",
     [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
@@ -143,6 +147,7 @@ def test_model_garden_generate(
     assert len(output.generations) == 2


+@pytest.mark.skip("CI testing not set up")
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "endpoint_os_variable_name,result_arg",
integration tests: safety settings (new file)

@@ -0,0 +1,97 @@
+from langchain_core.outputs import LLMResult
+
+from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI
+
+SAFETY_SETTINGS = {
+    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+}
+
+
+# below context and question are taken from one of opensource QA datasets
+BLOCKED_PROMPT = """
+You are agent designed to answer questions.
+You are given context in triple backticks.
+```
+The religion\'s failure to report abuse allegations to authorities has also been
+criticized. The Watch Tower Society\'s policy is that elders inform authorities when
+required by law to do so, but otherwise leave that action up to the victim and his
+or her family. The Australian Royal Commission into Institutional Responses to Child
+Sexual Abuse found that of 1006 alleged perpetrators of child sexual abuse
+identified by the Jehovah\'s Witnesses within their organization since 1950,
+"not one was reported by the church to secular authorities." William Bowen, a former
+Jehovah\'s Witness elder who established the Silentlambs organization to assist sex
+abuse victims within the religion, has claimed Witness leaders discourage followers
+from reporting incidents of sexual misconduct to authorities, and other critics claim
+the organization is reluctant to alert authorities in order to protect its "crime-free"
+reputation. In court cases in the United Kingdom and the United States the Watch Tower
+Society has been found to have been negligent in its failure to protect children from
+known sex offenders within the congregation and the Society has settled other child
+abuse lawsuits out of court, reportedly paying as much as $780,000 to one plaintiff
+without admitting wrongdoing.
+```
+Question: What have courts in both the UK and the US found the Watch Tower Society to
+have been for failing to protect children from sexual predators within the
+congregation ?
+Answer:
+"""
+
+
+def test_gemini_safety_settings_generate() -> None:
+    llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
+    output = llm.generate(["What do you think about child abuse:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    generation_info = output.generations[0][0].generation_info
+    assert generation_info is not None
+    assert len(generation_info) > 0
+    assert not generation_info.get("is_blocked")
+
+    blocked_output = llm.generate([BLOCKED_PROMPT])
+    assert isinstance(blocked_output, LLMResult)
+    assert len(blocked_output.generations) == 1
+    assert len(blocked_output.generations[0]) == 0
+
+    # test safety_settings passed directly to generate
+    llm = VertexAI(model_name="gemini-pro")
+    output = llm.generate(
+        ["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
+    )
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    generation_info = output.generations[0][0].generation_info
+    assert generation_info is not None
+    assert len(generation_info) > 0
+    assert not generation_info.get("is_blocked")
+
+
+async def test_gemini_safety_settings_agenerate() -> None:
+    llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
+    output = await llm.agenerate(["What do you think about child abuse:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    generation_info = output.generations[0][0].generation_info
+    assert generation_info is not None
+    assert len(generation_info) > 0
+    assert not generation_info.get("is_blocked")
+
+    blocked_output = await llm.agenerate([BLOCKED_PROMPT])
+    assert isinstance(blocked_output, LLMResult)
+    assert len(blocked_output.generations) == 1
+    # assert len(blocked_output.generations[0][0].generation_info) > 0
+    # assert blocked_output.generations[0][0].generation_info.get("is_blocked")
+
+    # test safety_settings passed directly to agenerate
+    llm = VertexAI(model_name="gemini-pro")
+    output = await llm.agenerate(
+        ["What do you think about child abuse:"], safety_settings=SAFETY_SETTINGS
+    )
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    generation_info = output.generations[0][0].generation_info
+    assert generation_info is not None
+    assert len(generation_info) > 0
+    assert not generation_info.get("is_blocked")
unit tests: imports

@@ -1,6 +1,13 @@
 from langchain_google_vertexai import __all__

-EXPECTED_ALL = ["ChatVertexAI", "VertexAIEmbeddings", "VertexAI", "VertexAIModelGarden"]
+EXPECTED_ALL = [
+    "ChatVertexAI",
+    "VertexAIEmbeddings",
+    "VertexAI",
+    "VertexAIModelGarden",
+    "HarmBlockThreshold",
+    "HarmCategory",
+]


 def test_all_imports() -> None: