From f752c1a07f389ab261a6867400e2359f9b4dcd6e Mon Sep 17 00:00:00 2001 From: Christophe Lamarche Date: Fri, 19 Dec 2025 14:04:13 -0500 Subject: [PATCH] feat(langchain): Add support to `google_genai` provider in `init_embeddings` (#34388) --- .../langchain_classic/embeddings/base.py | 6 +++++ .../tests/unit_tests/embeddings/test_base.py | 27 ++++++++++--------- .../langchain_v1/langchain/embeddings/base.py | 5 ++++ .../tests/unit_tests/embeddings/test_base.py | 25 ++++++++--------- 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/libs/langchain/langchain_classic/embeddings/base.py b/libs/langchain/langchain_classic/embeddings/base.py index 7fc32abc98e..e79aa0ff8b2 100644 --- a/libs/langchain/langchain_classic/embeddings/base.py +++ b/libs/langchain/langchain_classic/embeddings/base.py @@ -9,6 +9,7 @@ _SUPPORTED_PROVIDERS = { "azure_openai": "langchain_openai", "bedrock": "langchain_aws", "cohere": "langchain_cohere", + "google_genai": "langchain_google_genai", "google_vertexai": "langchain_google_vertexai", "huggingface": "langchain_huggingface", "mistralai": "langchain_mistralai", @@ -155,6 +156,7 @@ def init_embeddings( - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai) - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws) - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere) + - `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google) - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google) - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface) - `mistraiai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai) @@ -207,6 +209,10 @@ def init_embeddings( from langchain_openai import AzureOpenAIEmbeddings return AzureOpenAIEmbeddings(model=model_name, **kwargs) + if provider == "google_genai": + from langchain_google_genai import GoogleGenerativeAIEmbeddings + + return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs) if provider == "google_vertexai": from langchain_google_vertexai import VertexAIEmbeddings diff --git a/libs/langchain/tests/unit_tests/embeddings/test_base.py b/libs/langchain/tests/unit_tests/embeddings/test_base.py index ea7e9ebbc04..3e541e6a782 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_base.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_base.py @@ -9,19 +9,22 @@ from langchain_classic.embeddings.base import ( ) -def test_parse_model_string() -> None: +@pytest.mark.parametrize( + ("model_string", "expected_provider", "expected_model"), + [ + ("openai:text-embedding-3-small", "openai", "text-embedding-3-small"), + ("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"), + ("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"), + ("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"), + ], +) +def test_parse_model_string( + model_string: str, expected_provider: str, expected_model: str +) -> None: """Test parsing model strings into provider and model components.""" - assert _parse_model_string("openai:text-embedding-3-small") == ( - "openai", - "text-embedding-3-small", - ) - assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == ( - "bedrock", - "amazon.titan-embed-text-v1", - ) - assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == ( - "huggingface", - "BAAI/bge-base-en:v1.5", + assert _parse_model_string(model_string) == ( + expected_provider, + expected_model, ) diff --git a/libs/langchain_v1/langchain/embeddings/base.py b/libs/langchain_v1/langchain/embeddings/base.py index 97b5c62a2ef..e7c9063f7ad 100644 --- a/libs/langchain_v1/langchain/embeddings/base.py +++ b/libs/langchain_v1/langchain/embeddings/base.py @@ -10,6 +10,7 @@ _SUPPORTED_PROVIDERS = { "azure_openai": "langchain_openai", "bedrock": "langchain_aws", "cohere": "langchain_cohere", + "google_genai": "langchain_google_genai", "google_vertexai": "langchain_google_vertexai", "huggingface": "langchain_huggingface", "mistralai": "langchain_mistralai", @@ -207,6 +208,10 @@ def init_embeddings( from langchain_openai import AzureOpenAIEmbeddings return AzureOpenAIEmbeddings(model=model_name, **kwargs) + if provider == "google_genai": + from langchain_google_genai import GoogleGenerativeAIEmbeddings + + return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs) if provider == "google_vertexai": from langchain_google_vertexai import VertexAIEmbeddings diff --git a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py index 24460855ce7..aa726d98f60 100644 --- a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py +++ b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py @@ -9,19 +9,20 @@ from langchain.embeddings.base import ( ) -def test_parse_model_string() -> None: +@pytest.mark.parametrize( + ("model_string", "expected_provider", "expected_model"), + [ + ("openai:text-embedding-3-small", "openai", "text-embedding-3-small"), + ("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"), + ("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"), + ("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"), + ], +) +def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None: """Test parsing model strings into provider and model components.""" - assert _parse_model_string("openai:text-embedding-3-small") == ( - "openai", - "text-embedding-3-small", - ) - assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == ( - "bedrock", - "amazon.titan-embed-text-v1", - ) - assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == ( - "huggingface", - "BAAI/bge-base-en:v1.5", + assert _parse_model_string(model_string) == ( + expected_provider, + expected_model, )