feat(langchain): Add support to google_genai provider in init_embeddings (#34388)

This commit is contained in:
Christophe Lamarche
2025-12-19 14:04:13 -05:00
committed by GitHub
parent 7902fa3238
commit f752c1a07f
4 changed files with 39 additions and 24 deletions

View File

@@ -9,6 +9,7 @@ _SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai", "azure_openai": "langchain_openai",
"bedrock": "langchain_aws", "bedrock": "langchain_aws",
"cohere": "langchain_cohere", "cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai", "google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface", "huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai", "mistralai": "langchain_mistralai",
@@ -155,6 +156,7 @@ def init_embeddings(
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai) - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws) - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere) - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google) - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface) - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai) - `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
@@ -207,6 +209,10 @@ def init_embeddings(
from langchain_openai import AzureOpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs) return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai": if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings from langchain_google_vertexai import VertexAIEmbeddings

View File

@@ -9,19 +9,22 @@ from langchain_classic.embeddings.base import (
) )
def test_parse_model_string() -> None: @pytest.mark.parametrize(
("model_string", "expected_provider", "expected_model"),
[
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
],
)
def test_parse_model_string(
model_string: str, expected_provider: str, expected_model: str
) -> None:
"""Test parsing model strings into provider and model components.""" """Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == ( assert _parse_model_string(model_string) == (
"openai", expected_provider,
"text-embedding-3-small", expected_model,
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
) )

View File

@@ -10,6 +10,7 @@ _SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai", "azure_openai": "langchain_openai",
"bedrock": "langchain_aws", "bedrock": "langchain_aws",
"cohere": "langchain_cohere", "cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai", "google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface", "huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai", "mistralai": "langchain_mistralai",
@@ -207,6 +208,10 @@ def init_embeddings(
from langchain_openai import AzureOpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs) return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai": if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings from langchain_google_vertexai import VertexAIEmbeddings

View File

@@ -9,19 +9,20 @@ from langchain.embeddings.base import (
) )
def test_parse_model_string() -> None: @pytest.mark.parametrize(
("model_string", "expected_provider", "expected_model"),
[
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
],
)
def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None:
"""Test parsing model strings into provider and model components.""" """Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == ( assert _parse_model_string(model_string) == (
"openai", expected_provider,
"text-embedding-3-small", expected_model,
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
) )