diff --git a/libs/partners/google-genai/langchain_google_genai/__init__.py b/libs/partners/google-genai/langchain_google_genai/__init__.py index 505e121d1e2..187f7e3e036 100644 --- a/libs/partners/google-genai/langchain_google_genai/__init__.py +++ b/libs/partners/google-genai/langchain_google_genai/__init__.py @@ -54,6 +54,8 @@ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") embeddings.embed_query("hello, world!") ``` """ # noqa: E501 + +from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory from langchain_google_genai.chat_models import ChatGoogleGenerativeAI from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings from langchain_google_genai.llms import GoogleGenerativeAI @@ -62,4 +64,6 @@ __all__ = [ "ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings", "GoogleGenerativeAI", + "HarmBlockThreshold", + "HarmCategory", ] diff --git a/libs/partners/google-genai/langchain_google_genai/_enums.py b/libs/partners/google-genai/langchain_google_genai/_enums.py new file mode 100644 index 00000000000..b2a1de2614e --- /dev/null +++ b/libs/partners/google-genai/langchain_google_genai/_enums.py @@ -0,0 +1,6 @@ +from google.generativeai.types.safety_types import ( # type: ignore + HarmBlockThreshold, + HarmCategory, +) + +__all__ = ["HarmBlockThreshold", "HarmCategory"] diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py index 496f54add55..4f0e9f0a223 100644 --- a/libs/partners/google-genai/langchain_google_genai/chat_models.py +++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py @@ -517,6 +517,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): "temperature": self.temperature, "top_k": self.top_k, "n": self.n, + "safety_settings": self.safety_settings, } def _prepare_params( @@ -549,7 +550,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): params, chat, message = self._prepare_chat( messages, stop=stop, - functions=kwargs.get("functions"), + **kwargs, ) response: genai.types.GenerateContentResponse = _chat_with_retry( content=message, @@ -568,7 +569,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): params, chat, message = self._prepare_chat( messages, stop=stop, - functions=kwargs.get("functions"), + **kwargs, ) response: genai.types.GenerateContentResponse = await _achat_with_retry( content=message, @@ -587,7 +588,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): params, chat, message = self._prepare_chat( messages, stop=stop, - functions=kwargs.get("functions"), + **kwargs, ) response: genai.types.GenerateContentResponse = _chat_with_retry( content=message, @@ -613,7 +614,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): params, chat, message = self._prepare_chat( messages, stop=stop, - functions=kwargs.get("functions"), + **kwargs, ) async for chunk in await _achat_with_retry( content=message, @@ -636,9 +637,14 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel): ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]: client = self.client functions = kwargs.pop("functions", None) - if functions: - tools = convert_to_genai_function_declarations(functions) - client = genai.GenerativeModel(model_name=self.model, tools=tools) + safety_settings = kwargs.pop("safety_settings", self.safety_settings) + if functions or safety_settings: + tools = ( + convert_to_genai_function_declarations(functions) if functions else None + ) + client = genai.GenerativeModel( + model_name=self.model, tools=tools, safety_settings=safety_settings + ) params = self._prepare_params(stop, **kwargs) history = _parse_chat_history( diff --git a/libs/partners/google-genai/langchain_google_genai/llms.py b/libs/partners/google-genai/langchain_google_genai/llms.py index d6eb16d9819..9b483873146 100644 --- a/libs/partners/google-genai/langchain_google_genai/llms.py +++ b/libs/partners/google-genai/langchain_google_genai/llms.py @@ -15,6 +15,11 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import get_from_dict_or_env +from langchain_google_genai._enums import ( + HarmBlockThreshold, + HarmCategory, +) + class GoogleModelFamily(str, Enum): GEMINI = auto() @@ -77,7 +82,10 @@ def _completion_with_retry( try: if is_gemini: return llm.client.generate_content( - contents=prompt, stream=stream, generation_config=generation_config + contents=prompt, + stream=stream, + generation_config=generation_config, + safety_settings=kwargs.pop("safety_settings", None), ) return llm.client.generate_text(prompt=prompt, **kwargs) except google.api_core.exceptions.FailedPrecondition as exc: @@ -143,6 +151,22 @@ Supported examples: description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].", ) + safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None + """The default safety settings to use for all generations. + + For example: + + from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory + + safety_settings = { + HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + } + """ # noqa: E501 + @property def lc_secrets(self) -> Dict[str, str]: return {"google_api_key": "GOOGLE_API_KEY"} @@ -184,6 +208,8 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): ) model_name = values["model"] + safety_settings = values["safety_settings"] + if isinstance(google_api_key, SecretStr): google_api_key = google_api_key.get_secret_value() @@ -193,8 +219,15 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): client_options=values.get("client_options"), ) + if safety_settings and ( + not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI + ): + raise ValueError("Safety settings are only supported for Gemini models") + if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI: - values["client"] = genai.GenerativeModel(model_name=model_name) + values["client"] = genai.GenerativeModel( + model_name=model_name, safety_settings=safety_settings + ) else: values["client"] = genai @@ -237,6 +270,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): is_gemini=True, run_manager=run_manager, generation_config=generation_config, + safety_settings=kwargs.pop("safety_settings", None), ) candidates = [ "".join([p.text for p in c.content.parts]) for c in res.candidates @@ -278,6 +312,7 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): is_gemini=True, run_manager=run_manager, generation_config=generation_config, + safety_settings=kwargs.pop("safety_settings", None), **kwargs, ): chunk = GenerationChunk(text=stream_resp.text) diff --git a/libs/partners/google-genai/poetry.lock b/libs/partners/google-genai/poetry.lock index 60a87e151ec..15ee5ceb50a 100644 --- a/libs/partners/google-genai/poetry.lock +++ b/libs/partners/google-genai/poetry.lock @@ -1146,13 +1146,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.1" +version = "4.66.2" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, - {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, + {file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"}, + {file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"}, ] [package.dependencies] diff --git a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py index 91e6aab20fb..26e85aa10e1 100644 --- a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py @@ -1,12 +1,16 @@ """Test ChatGoogleGenerativeAI chat model.""" +from typing import Generator + import pytest from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_google_genai.chat_models import ( +from langchain_google_genai import ( ChatGoogleGenerativeAI, - ChatGoogleGenerativeAIError, + HarmBlockThreshold, + HarmCategory, ) +from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError _MODEL = "gemini-pro" # TODO: Use nano when it's available. _VISION_MODEL = "gemini-pro-vision" @@ -193,3 +197,32 @@ def test_generativeai_get_num_tokens_gemini() -> None: llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro") output = llm.get_num_tokens("How are you?") assert output == 4 + + +def test_safety_settings_gemini() -> None: + safety_settings = { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + # test with safety filters on bind + llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro").bind( + safety_settings=safety_settings + ) + output = llm.invoke("how to make a bomb?") + assert isinstance(output, AIMessage) + assert len(output.content) > 0 + + # test direct to stream + streamed_messages = [] + output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings) + assert isinstance(output_stream, Generator) + for message in output_stream: + streamed_messages.append(message) + assert len(streamed_messages) > 0 + + # test as init param + llm = ChatGoogleGenerativeAI( + temperature=0, model="gemini-pro", safety_settings=safety_settings + ) + out2 = llm.invoke("how to make a bomb") + assert isinstance(out2, AIMessage) + assert len(out2.content) > 0 diff --git a/libs/partners/google-genai/tests/integration_tests/test_llms.py b/libs/partners/google-genai/tests/integration_tests/test_llms.py index 9bdf49dda84..b761e148041 100644 --- a/libs/partners/google-genai/tests/integration_tests/test_llms.py +++ b/libs/partners/google-genai/tests/integration_tests/test_llms.py @@ -4,10 +4,12 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to valid API key. """ +from typing import Generator + import pytest from langchain_core.outputs import LLMResult -from langchain_google_genai.llms import GoogleGenerativeAI +from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory model_names = ["models/text-bison-001", "gemini-pro"] @@ -66,3 +68,39 @@ def test_generativeai_get_num_tokens_gemini() -> None: llm = GoogleGenerativeAI(temperature=0, model="gemini-pro") output = llm.get_num_tokens("How are you?") assert output == 4 + + +def test_safety_settings_gemini() -> None: + # test with blocked prompt + llm = GoogleGenerativeAI(temperature=0, model="gemini-pro") + output = llm.generate(prompts=["how to make a bomb?"]) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) == 0 + + # safety filters + safety_settings = { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + + # test with safety filters directly to generate + output = llm.generate(["how to make a bomb?"], safety_settings=safety_settings) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) > 0 + + # test with safety filters directly to stream + streamed_messages = [] + output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings) + assert isinstance(output_stream, Generator) + for message in output_stream: + streamed_messages.append(message) + assert len(streamed_messages) > 0 + + # test with safety filters on instantiation + llm = GoogleGenerativeAI( + model="gemini-pro", + safety_settings=safety_settings, + temperature=0, + ) + output = llm.generate(prompts=["how to make a bomb?"]) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) > 0 diff --git a/libs/partners/google-genai/tests/unit_tests/test_imports.py b/libs/partners/google-genai/tests/unit_tests/test_imports.py index e189a9fc775..8c90cb2bd10 100644 --- a/libs/partners/google-genai/tests/unit_tests/test_imports.py +++ b/libs/partners/google-genai/tests/unit_tests/test_imports.py @@ -4,6 +4,8 @@ EXPECTED_ALL = [ "ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings", "GoogleGenerativeAI", + "HarmBlockThreshold", + "HarmCategory", ]