google-vertexai[minor]: added safety_settings property to gemini wrapper (#15344)

**Description:** Gemini model has quite annoying default safety_settings
settings. In addition, current VertexAI class doesn't provide a property
to override such settings.
So, this PR aims to 
 - add safety_settings property to VertexAI
- fix issue with incorrect LLM output parsing when LLM responds with
appropriate 'blocked' response
- fix issue with incorrect parsing LLM output when Gemini API blocks
prompt itself as inappropriate
- add safety_settings related tests

I'm not enough familiar with langchain code base and guidelines. So, any
comments and/or suggestions are very welcome.
 
**Issue:** it will likely fix #14841

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Eugene Zapolsky
2024-01-18 18:54:30 +02:00
committed by GitHub
parent ecd4f0a7ec
commit 6b9e3ed9e9
11 changed files with 448 additions and 106 deletions

View File

@@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {
"tags": []
},
@@ -44,10 +44,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
"^C\n",
"\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n",
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
]
}
],
@@ -57,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -67,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -76,7 +75,7 @@
"AIMessage(content=\" J'aime la programmation.\")"
]
},
"execution_count": 2,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -101,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -110,7 +109,7 @@
"AIMessage(content=' プログラミングが大好きです')"
]
},
"execution_count": 3,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -154,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"tags": []
},
@@ -165,27 +164,51 @@
"text": [
" ```python\n",
"def is_prime(n):\n",
" if n <= 1:\n",
" return False\n",
" for i in range(2, n):\n",
" if n % i == 0:\n",
" return False\n",
" return True\n",
" \"\"\"\n",
" Check if a number is prime.\n",
"\n",
" Args:\n",
" n: The number to check.\n",
"\n",
" Returns:\n",
" True if n is prime, False otherwise.\n",
" \"\"\"\n",
"\n",
" # If n is 1, it is not prime.\n",
" if n == 1:\n",
" return False\n",
"\n",
" # Iterate over all numbers from 2 to the square root of n.\n",
" for i in range(2, int(n ** 0.5) + 1):\n",
" # If n is divisible by any number from 2 to its square root, it is not prime.\n",
" if n % i == 0:\n",
" return False\n",
"\n",
" # If n is divisible by no number from 2 to its square root, it is prime.\n",
" return True\n",
"\n",
"\n",
"def find_prime_numbers(n):\n",
" prime_numbers = []\n",
" for i in range(2, n + 1):\n",
" if is_prime(i):\n",
" prime_numbers.append(i)\n",
" return prime_numbers\n",
" \"\"\"\n",
" Find all prime numbers up to a given number.\n",
"\n",
"print(find_prime_numbers(100))\n",
"```\n",
" Args:\n",
" n: The upper bound for the prime numbers to find.\n",
"\n",
"Output:\n",
" Returns:\n",
" A list of all prime numbers up to n.\n",
" \"\"\"\n",
"\n",
"```\n",
"[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]\n",
" # Create a list of all numbers from 2 to n.\n",
" numbers = list(range(2, n + 1))\n",
"\n",
" # Iterate over the list of numbers and remove any that are not prime.\n",
" for number in numbers:\n",
" if not is_prime(number):\n",
" numbers.remove(number)\n",
"\n",
" # Return the list of prime numbers.\n",
" return numbers\n",
"```\n"
]
}
@@ -199,6 +222,102 @@
"print(message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full generation info\n",
"\n",
"We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just chat completions\n",
"\n",
"Note that the `generation_info` will be different depending if you're using a gemini model or not.\n",
"\n",
"### Gemini model\n",
"\n",
"`generation_info` will include:\n",
"\n",
"- `is_blocked`: whether generation was blocked or not\n",
"- `safety_ratings`: safety ratings' categories and probability labels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'is_blocked': False,\n",
" 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_HATE_SPEECH',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n",
" 'probability_label': 'NEGLIGIBLE'},\n",
" {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n",
" 'probability_label': 'NEGLIGIBLE'}]}\n"
]
}
],
"source": [
"from pprint import pprint\n",
"\n",
"from langchain_core.messages import HumanMessage\n",
"from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory\n",
"\n",
"human = \"Translate this sentence from English to French. I love programming.\"\n",
"messages = [HumanMessage(content=human)]\n",
"\n",
"\n",
"chat = ChatVertexAI(\n",
" model_name=\"gemini-pro\",\n",
" safety_settings={\n",
" HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE\n",
" },\n",
")\n",
"\n",
"result = chat.generate([messages])\n",
"pprint(result.generations[0][0].generation_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-gemini model\n",
"\n",
"`generation_info` will include:\n",
"\n",
"- `is_blocked`: whether generation was blocked or not\n",
"- `safety_attributes`: a dictionary mapping safety attributes to their scores"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'is_blocked': False,\n",
" 'safety_attributes': {'Derogatory': 0.1,\n",
" 'Finance': 0.3,\n",
" 'Insult': 0.1,\n",
" 'Sexual': 0.1}}\n"
]
}
],
"source": [
"chat = ChatVertexAI() # default is `chat-bison`\n",
"\n",
"result = chat.generate([messages])\n",
"pprint(result.generations[0][0].generation_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -210,7 +329,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -224,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -268,7 +387,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{