feat(langchain): Add support to google_genai provider in init_embeddings (#34388)

This commit is contained in:
Christophe Lamarche
2025-12-19 14:04:13 -05:00
committed by GitHub
parent 7902fa3238
commit f752c1a07f
4 changed files with 39 additions and 24 deletions

View File

@@ -10,6 +10,7 @@ _SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_genai": "langchain_google_genai",
"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
@@ -207,6 +208,10 @@ def init_embeddings(
from langchain_openai import AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
if provider == "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
if provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings

View File

@@ -9,19 +9,20 @@ from langchain.embeddings.base import (
)
def test_parse_model_string() -> None:
@pytest.mark.parametrize(
("model_string", "expected_provider", "expected_model"),
[
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
],
)
def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None:
"""Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
assert _parse_model_string(model_string) == (
expected_provider,
expected_model,
)