mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 01:13:48 +00:00
google-genai[minor]: add safety settings (#16836)
Replace this entire comment with: - **Description:** Expose `safety_settings` for Gemini integrations on google-generativeai - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** @aditya_rane; @lkuligin for review --------- Co-authored-by: adityarane@google.com <adityarane@google.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
584b647b96
commit
a23c719c8b
@ -54,6 +54,8 @@ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
|
||||
embeddings.embed_query("hello, world!")
|
||||
```
|
||||
""" # noqa: E501
|
||||
|
||||
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
from langchain_google_genai.llms import GoogleGenerativeAI
|
||||
@ -62,4 +64,6 @@ __all__ = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
"GoogleGenerativeAIEmbeddings",
|
||||
"GoogleGenerativeAI",
|
||||
"HarmBlockThreshold",
|
||||
"HarmCategory",
|
||||
]
|
||||
|
@ -0,0 +1,6 @@
|
||||
from google.generativeai.types.safety_types import ( # type: ignore
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
)
|
||||
|
||||
__all__ = ["HarmBlockThreshold", "HarmCategory"]
|
@ -517,6 +517,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
"safety_settings": self.safety_settings,
|
||||
}
|
||||
|
||||
def _prepare_params(
|
||||
@ -549,7 +550,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
params, chat, message = self._prepare_chat(
|
||||
messages,
|
||||
stop=stop,
|
||||
functions=kwargs.get("functions"),
|
||||
**kwargs,
|
||||
)
|
||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||
content=message,
|
||||
@ -568,7 +569,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
params, chat, message = self._prepare_chat(
|
||||
messages,
|
||||
stop=stop,
|
||||
functions=kwargs.get("functions"),
|
||||
**kwargs,
|
||||
)
|
||||
response: genai.types.GenerateContentResponse = await _achat_with_retry(
|
||||
content=message,
|
||||
@ -587,7 +588,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
params, chat, message = self._prepare_chat(
|
||||
messages,
|
||||
stop=stop,
|
||||
functions=kwargs.get("functions"),
|
||||
**kwargs,
|
||||
)
|
||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||
content=message,
|
||||
@ -613,7 +614,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
params, chat, message = self._prepare_chat(
|
||||
messages,
|
||||
stop=stop,
|
||||
functions=kwargs.get("functions"),
|
||||
**kwargs,
|
||||
)
|
||||
async for chunk in await _achat_with_retry(
|
||||
content=message,
|
||||
@ -636,9 +637,14 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
|
||||
client = self.client
|
||||
functions = kwargs.pop("functions", None)
|
||||
if functions:
|
||||
tools = convert_to_genai_function_declarations(functions)
|
||||
client = genai.GenerativeModel(model_name=self.model, tools=tools)
|
||||
safety_settings = kwargs.pop("safety_settings", self.safety_settings)
|
||||
if functions or safety_settings:
|
||||
tools = (
|
||||
convert_to_genai_function_declarations(functions) if functions else None
|
||||
)
|
||||
client = genai.GenerativeModel(
|
||||
model_name=self.model, tools=tools, safety_settings=safety_settings
|
||||
)
|
||||
|
||||
params = self._prepare_params(stop, **kwargs)
|
||||
history = _parse_chat_history(
|
||||
|
@ -15,6 +15,11 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
|
||||
from langchain_google_genai._enums import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
)
|
||||
|
||||
|
||||
class GoogleModelFamily(str, Enum):
|
||||
GEMINI = auto()
|
||||
@ -77,7 +82,10 @@ def _completion_with_retry(
|
||||
try:
|
||||
if is_gemini:
|
||||
return llm.client.generate_content(
|
||||
contents=prompt, stream=stream, generation_config=generation_config
|
||||
contents=prompt,
|
||||
stream=stream,
|
||||
generation_config=generation_config,
|
||||
safety_settings=kwargs.pop("safety_settings", None),
|
||||
)
|
||||
return llm.client.generate_text(prompt=prompt, **kwargs)
|
||||
except google.api_core.exceptions.FailedPrecondition as exc:
|
||||
@ -143,6 +151,22 @@ Supported examples:
|
||||
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
|
||||
)
|
||||
|
||||
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
|
||||
"""The default safety settings to use for all generations.
|
||||
|
||||
For example:
|
||||
|
||||
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
|
||||
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
""" # noqa: E501
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
@ -184,6 +208,8 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
||||
)
|
||||
model_name = values["model"]
|
||||
|
||||
safety_settings = values["safety_settings"]
|
||||
|
||||
if isinstance(google_api_key, SecretStr):
|
||||
google_api_key = google_api_key.get_secret_value()
|
||||
|
||||
@ -193,8 +219,15 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
||||
client_options=values.get("client_options"),
|
||||
)
|
||||
|
||||
if safety_settings and (
|
||||
not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
|
||||
):
|
||||
raise ValueError("Safety settings are only supported for Gemini models")
|
||||
|
||||
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
|
||||
values["client"] = genai.GenerativeModel(model_name=model_name)
|
||||
values["client"] = genai.GenerativeModel(
|
||||
model_name=model_name, safety_settings=safety_settings
|
||||
)
|
||||
else:
|
||||
values["client"] = genai
|
||||
|
||||
@ -237,6 +270,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
safety_settings=kwargs.pop("safety_settings", None),
|
||||
)
|
||||
candidates = [
|
||||
"".join([p.text for p in c.content.parts]) for c in res.candidates
|
||||
@ -278,6 +312,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
|
||||
is_gemini=True,
|
||||
run_manager=run_manager,
|
||||
generation_config=generation_config,
|
||||
safety_settings=kwargs.pop("safety_settings", None),
|
||||
**kwargs,
|
||||
):
|
||||
chunk = GenerationChunk(text=stream_resp.text)
|
||||
|
6
libs/partners/google-genai/poetry.lock
generated
6
libs/partners/google-genai/poetry.lock
generated
@ -1146,13 +1146,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.1"
|
||||
version = "4.66.2"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
|
||||
{file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
|
||||
{file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
|
||||
{file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1,12 +1,16 @@
|
||||
"""Test ChatGoogleGenerativeAI chat model."""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_google_genai.chat_models import (
|
||||
from langchain_google_genai import (
|
||||
ChatGoogleGenerativeAI,
|
||||
ChatGoogleGenerativeAIError,
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
)
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
|
||||
|
||||
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
|
||||
_VISION_MODEL = "gemini-pro-vision"
|
||||
@ -193,3 +197,32 @@ def test_generativeai_get_num_tokens_gemini() -> None:
|
||||
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
||||
|
||||
def test_safety_settings_gemini() -> None:
    """Exercise ``safety_settings`` on the chat model via bind, stream and init.

    The same mapping (dangerous-content category -> ``BLOCK_NONE``) is supplied
    three ways: bound onto the runnable, passed again as a keyword to
    ``stream``, and given as a constructor argument. Each path is expected to
    produce non-empty output.
    """
    safety_settings = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # Path 1: settings attached with ``bind``.
    bound_llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro").bind(
        safety_settings=safety_settings
    )
    bound_reply = bound_llm.invoke("how to make a bomb?")
    assert isinstance(bound_reply, AIMessage)
    assert len(bound_reply.content) > 0

    # Path 2: settings passed directly to ``stream`` (on the bound runnable).
    chunk_stream = bound_llm.stream(
        "how to make a bomb?", safety_settings=safety_settings
    )
    assert isinstance(chunk_stream, Generator)
    received_chunks = list(chunk_stream)
    assert len(received_chunks) > 0

    # Path 3: settings supplied as an init parameter.
    init_llm = ChatGoogleGenerativeAI(
        temperature=0, model="gemini-pro", safety_settings=safety_settings
    )
    init_reply = init_llm.invoke("how to make a bomb")
    assert isinstance(init_reply, AIMessage)
    assert len(init_reply.content) > 0
|
||||
|
@ -4,10 +4,12 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
|
||||
valid API key.
|
||||
"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_google_genai.llms import GoogleGenerativeAI
|
||||
from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
|
||||
|
||||
model_names = ["models/text-bison-001", "gemini-pro"]
|
||||
|
||||
@ -66,3 +68,39 @@ def test_generativeai_get_num_tokens_gemini() -> None:
|
||||
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
||||
|
||||
def test_safety_settings_gemini() -> None:
    """Exercise ``safety_settings`` on the LLM via generate, stream and init.

    First asserts that, without any safety settings, the prompt yields no
    generations. Then supplies a mapping (dangerous-content category ->
    ``BLOCK_NONE``) per call to ``generate``, per call to ``stream``, and at
    construction time, expecting non-empty output each time.
    """
    prompt = "how to make a bomb?"

    # Baseline: no safety settings, so the result carries no generations.
    default_llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    blocked = default_llm.generate(prompts=[prompt])
    assert isinstance(blocked, LLMResult)
    assert len(blocked.generations[0]) == 0

    safety_settings = {
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # Settings passed directly to ``generate``.
    per_call = default_llm.generate([prompt], safety_settings=safety_settings)
    assert isinstance(per_call, LLMResult)
    assert len(per_call.generations[0]) > 0

    # Settings passed directly to ``stream``.
    chunk_stream = default_llm.stream(prompt, safety_settings=safety_settings)
    assert isinstance(chunk_stream, Generator)
    received_chunks = list(chunk_stream)
    assert len(received_chunks) > 0

    # Settings supplied at instantiation.
    configured_llm = GoogleGenerativeAI(
        model="gemini-pro",
        safety_settings=safety_settings,
        temperature=0,
    )
    configured = configured_llm.generate(prompts=[prompt])
    assert isinstance(configured, LLMResult)
    assert len(configured.generations[0]) > 0
|
||||
|
@ -4,6 +4,8 @@ EXPECTED_ALL = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
"GoogleGenerativeAIEmbeddings",
|
||||
"GoogleGenerativeAI",
|
||||
"HarmBlockThreshold",
|
||||
"HarmCategory",
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user