mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(langchain): Add support to google_genai provider in init_embeddings (#34388)
This commit is contained in:
committed by
GitHub
parent
7902fa3238
commit
f752c1a07f
@@ -9,6 +9,7 @@ _SUPPORTED_PROVIDERS = {
|
|||||||
"azure_openai": "langchain_openai",
|
"azure_openai": "langchain_openai",
|
||||||
"bedrock": "langchain_aws",
|
"bedrock": "langchain_aws",
|
||||||
"cohere": "langchain_cohere",
|
"cohere": "langchain_cohere",
|
||||||
|
"google_genai": "langchain_google_genai",
|
||||||
"google_vertexai": "langchain_google_vertexai",
|
"google_vertexai": "langchain_google_vertexai",
|
||||||
"huggingface": "langchain_huggingface",
|
"huggingface": "langchain_huggingface",
|
||||||
"mistralai": "langchain_mistralai",
|
"mistralai": "langchain_mistralai",
|
||||||
@@ -155,6 +156,7 @@ def init_embeddings(
|
|||||||
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
|
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
|
||||||
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
|
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
|
||||||
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
|
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
|
||||||
|
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
|
||||||
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
|
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
|
||||||
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
|
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
|
||||||
- `mistraiai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
|
- `mistraiai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
|
||||||
@@ -207,6 +209,10 @@ def init_embeddings(
|
|||||||
from langchain_openai import AzureOpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings
|
||||||
|
|
||||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
||||||
|
if provider == "google_genai":
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||||
|
|
||||||
|
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
|
||||||
if provider == "google_vertexai":
|
if provider == "google_vertexai":
|
||||||
from langchain_google_vertexai import VertexAIEmbeddings
|
from langchain_google_vertexai import VertexAIEmbeddings
|
||||||
|
|
||||||
|
|||||||
@@ -9,19 +9,22 @@ from langchain_classic.embeddings.base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_model_string() -> None:
|
@pytest.mark.parametrize(
|
||||||
|
("model_string", "expected_provider", "expected_model"),
|
||||||
|
[
|
||||||
|
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
|
||||||
|
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
|
||||||
|
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
|
||||||
|
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_model_string(
|
||||||
|
model_string: str, expected_provider: str, expected_model: str
|
||||||
|
) -> None:
|
||||||
"""Test parsing model strings into provider and model components."""
|
"""Test parsing model strings into provider and model components."""
|
||||||
assert _parse_model_string("openai:text-embedding-3-small") == (
|
assert _parse_model_string(model_string) == (
|
||||||
"openai",
|
expected_provider,
|
||||||
"text-embedding-3-small",
|
expected_model,
|
||||||
)
|
|
||||||
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
|
|
||||||
"bedrock",
|
|
||||||
"amazon.titan-embed-text-v1",
|
|
||||||
)
|
|
||||||
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
|
|
||||||
"huggingface",
|
|
||||||
"BAAI/bge-base-en:v1.5",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ _SUPPORTED_PROVIDERS = {
|
|||||||
"azure_openai": "langchain_openai",
|
"azure_openai": "langchain_openai",
|
||||||
"bedrock": "langchain_aws",
|
"bedrock": "langchain_aws",
|
||||||
"cohere": "langchain_cohere",
|
"cohere": "langchain_cohere",
|
||||||
|
"google_genai": "langchain_google_genai",
|
||||||
"google_vertexai": "langchain_google_vertexai",
|
"google_vertexai": "langchain_google_vertexai",
|
||||||
"huggingface": "langchain_huggingface",
|
"huggingface": "langchain_huggingface",
|
||||||
"mistralai": "langchain_mistralai",
|
"mistralai": "langchain_mistralai",
|
||||||
@@ -207,6 +208,10 @@ def init_embeddings(
|
|||||||
from langchain_openai import AzureOpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings
|
||||||
|
|
||||||
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
return AzureOpenAIEmbeddings(model=model_name, **kwargs)
|
||||||
|
if provider == "google_genai":
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||||
|
|
||||||
|
return GoogleGenerativeAIEmbeddings(model=model_name, **kwargs)
|
||||||
if provider == "google_vertexai":
|
if provider == "google_vertexai":
|
||||||
from langchain_google_vertexai import VertexAIEmbeddings
|
from langchain_google_vertexai import VertexAIEmbeddings
|
||||||
|
|
||||||
|
|||||||
@@ -9,19 +9,20 @@ from langchain.embeddings.base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_model_string() -> None:
|
@pytest.mark.parametrize(
|
||||||
|
("model_string", "expected_provider", "expected_model"),
|
||||||
|
[
|
||||||
|
("openai:text-embedding-3-small", "openai", "text-embedding-3-small"),
|
||||||
|
("bedrock:amazon.titan-embed-text-v1", "bedrock", "amazon.titan-embed-text-v1"),
|
||||||
|
("huggingface:BAAI/bge-base-en:v1.5", "huggingface", "BAAI/bge-base-en:v1.5"),
|
||||||
|
("google_genai:gemini-embedding-001", "google_genai", "gemini-embedding-001"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_model_string(model_string: str, expected_provider: str, expected_model: str) -> None:
|
||||||
"""Test parsing model strings into provider and model components."""
|
"""Test parsing model strings into provider and model components."""
|
||||||
assert _parse_model_string("openai:text-embedding-3-small") == (
|
assert _parse_model_string(model_string) == (
|
||||||
"openai",
|
expected_provider,
|
||||||
"text-embedding-3-small",
|
expected_model,
|
||||||
)
|
|
||||||
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
|
|
||||||
"bedrock",
|
|
||||||
"amazon.titan-embed-text-v1",
|
|
||||||
)
|
|
||||||
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
|
|
||||||
"huggingface",
|
|
||||||
"BAAI/bge-base-en:v1.5",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user