[Partner] Gemini Embeddings (#14690)

Add support for Gemini embeddings in the langchain-google-genai package

This commit is contained in:
parent 3449fce273
commit 1e21a3f7ed
docs/docs/integrations/text_embedding/google_generative_ai.ipynb (new file, 220 lines)

@@ -0,0 +1,220 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "afab8b36-10bb-4795-bc98-75ab2d2081bb",
   "metadata": {},
   "source": [
    "# Google Generative AI Embeddings\n",
    "\n",
    "Connect to Google's generative AI embeddings service using the `GoogleGenerativeAIEmbeddings` class, found in the [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) package."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63545b38-9d56-4312-8f61-8d4f1e7a3b1b",
   "metadata": {},
   "source": [
    "## Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f6a3cd-379f-4dff-a449-d3a9f3196f2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U langchain-google-genai"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25f3f88e-164e-400d-b371-9fa488baba19",
   "metadata": {},
   "source": [
    "## Credentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec89153f-8999-4aab-a21b-0bfba1cc3893",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "if \"GOOGLE_API_KEY\" not in os.environ:\n",
    "    os.environ[\"GOOGLE_API_KEY\"] = getpass.getpass(\"Provide your Google API key here\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2437b22-e364-418a-8c13-490a026cb7b5",
   "metadata": {},
   "source": [
    "## Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eedc551e-a1f3-4fd8-8d65-4e0784c4441b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.05636945, 0.0048285457, -0.0762591, -0.023642512, 0.05329321]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_google_genai import GoogleGenerativeAIEmbeddings\n",
    "\n",
    "embeddings = GoogleGenerativeAIEmbeddings(model=\"models/embedding-001\")\n",
    "vector = embeddings.embed_query(\"hello, world!\")\n",
    "vector[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b2bed60-e7bd-4e48-83d6-1c87001f98bd",
   "metadata": {},
   "source": [
    "## Batch\n",
    "\n",
    "You can also embed multiple strings at once for a processing speedup:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6ec53aba-404f-4778-acd9-5d6664e79ed2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 768)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vectors = embeddings.embed_documents(\n",
    "    [\n",
    "        \"Today is Monday\",\n",
    "        \"Today is Tuesday\",\n",
    "        \"Today is April Fools day\",\n",
    "    ]\n",
    ")\n",
    "len(vectors), len(vectors[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1482486f-5617-498a-8a44-1974d3212dda",
   "metadata": {},
   "source": [
    "## Task type\n",
    "\n",
    "`GoogleGenerativeAIEmbeddings` optionally supports a `task_type`, which currently must be one of:\n",
    "\n",
    "- task_type_unspecified\n",
    "- retrieval_query\n",
    "- retrieval_document\n",
    "- semantic_similarity\n",
    "- classification\n",
    "- clustering\n",
    "\n",
    "By default, we use `retrieval_document` in the `embed_documents` method and `retrieval_query` in the `embed_query` method. If you provide a task type, we will use that for all methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a223bb25-2b1b-418e-a570-2f543083132e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install --quiet matplotlib scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f1f077db-8eb4-49f7-8866-471a8528dcdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_embeddings = GoogleGenerativeAIEmbeddings(\n",
    "    model=\"models/embedding-001\", task_type=\"retrieval_query\"\n",
    ")\n",
    "doc_embeddings = GoogleGenerativeAIEmbeddings(\n",
    "    model=\"models/embedding-001\", task_type=\"retrieval_document\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79bd4a5e-75ba-413c-befa-86167c938caf",
   "metadata": {},
   "source": [
    "All of these will be embedded with the 'retrieval_query' task set:\n",
    "```python\n",
    "query_vecs = [query_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n",
    "```\n",
    "All of these will be embedded with the 'retrieval_document' task set:\n",
    "```python\n",
    "doc_vecs = [doc_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e1fae5e-0f84-4812-89f5-7d4d71affbc1",
   "metadata": {},
   "source": [
    "In retrieval, relative distance matters. In the image above, you can see the difference in similarity scores between the \"relevant doc\" and \"similar query\" under the two task types, with a stronger delta between the similar query and the relevant doc in the latter case."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
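The closing cell above refers to an image comparing similarity scores, but the plot itself is not part of this file's source. As a rough sketch only (not part of the commit) of how such a comparison could be computed with the scikit-learn package the notebook installs: the strings assigned to `query`, `query_2`, and `answer_1` below are placeholder examples of my own choosing.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from langchain_google_genai import GoogleGenerativeAIEmbeddings

query_embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", task_type="retrieval_query"
)
doc_embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", task_type="retrieval_document"
)

query = "How do I get a driver's license?"  # placeholder example text
query_2 = "How do I renew my passport?"  # placeholder example text
answer_1 = "Visit your local DMV to apply for a license."  # placeholder example text

# Embed the queries with the query task type and the answer with the
# document task type, then compare them with cosine similarity.
query_vecs = np.array([query_embeddings.embed_query(q) for q in [query, query_2]])
doc_vec = np.array([doc_embeddings.embed_query(answer_1)])

# The relevant query should score noticeably higher against the doc.
print(cosine_similarity(query_vecs, doc_vec))
```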
libs/partners/google-genai/README.md

@@ -56,3 +56,16 @@ The value of `image_url` can be any of the following:
 - A local file path
 - A base64 encoded image (e.g., `data:image/png;base64,abcd124`)
 - A PIL image
+
+
+## Embeddings
+
+This package also adds support for Google's embedding models.
+
+```python
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+
+embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+embeddings.embed_query("hello, world!")
+```
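The README example embeds a single query; the new `GoogleGenerativeAIEmbeddings` class also exposes `embed_documents` for batches (see the embeddings.py portion of this diff). A minimal usage sketch, assuming a valid `GOOGLE_API_KEY` in the environment:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# embed_documents embeds a batch of texts in one call, returning one
# 768-dimensional vector per input string.
vectors = embeddings.embed_documents(["hello, world!", "goodbye, world!"])
print(len(vectors), len(vectors[0]))  # 2 768
```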
langchain_google_genai/__init__.py

@@ -1,3 +1,46 @@
-from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+"""**LangChain Google Generative AI Integration**
 
-__all__ = ["ChatGoogleGenerativeAI"]
+This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities.
+
+**Chat Models**
+
+The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.
+
+**Embeddings**
+
+The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
+These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.
+
+**Installation**
+
+To install the package, use pip:
+
+```bash
+pip install -U langchain-google-genai
+```
+
+## Using Chat Models
+
+After setting up your environment with the required API key, you can interact with the Google Gemini models.
+
+```python
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro")
+llm.invoke("Sing a ballad of LangChain.")
+```
+
+## Embedding Generation
+
+The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
+
+```python
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+
+embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+embeddings.embed_query("hello, world!")
+```
+"""  # noqa: E501
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+
+__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"]
langchain_google_genai/_common.py (new file)

@@ -0,0 +1,4 @@
class GoogleGenerativeAIError(Exception):
    """
    Custom exception class for errors associated with the `Google GenAI` API.
    """
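Because `ChatGoogleGenerativeAIError` is rebased onto this class below, and the new embeddings module wraps API failures in it, callers can catch one exception type across the package. A small illustrative sketch, assuming a configured API key:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai._common import GoogleGenerativeAIError

embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
try:
    embeddings.embed_query("hello, world!")
except GoogleGenerativeAIError as e:
    # _embed re-raises any failure from genai.embed_content wrapped in
    # this package-level exception type.
    print(f"Embedding failed: {e}")
```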
langchain_google_genai/chat_models.py

@@ -5,7 +5,6 @@ import logging
 import os
 from io import BytesIO
 from typing import (
-    TYPE_CHECKING,
     Any,
     AsyncIterator,
     Callable,
@@ -22,6 +21,8 @@ from typing import (
 )
 from urllib.parse import urlparse
 
+# TODO: remove ignore once the google package is published with types
+import google.generativeai as genai  # type: ignore[import]
 import requests
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -38,7 +39,7 @@ from langchain_core.messages import (
     HumanMessageChunk,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -48,11 +49,8 @@ from tenacity import (
     wait_exponential,
 )
 
-logger = logging.getLogger(__name__)
+from langchain_google_genai._common import GoogleGenerativeAIError
 
-if TYPE_CHECKING:
-    # TODO: remove ignore once the google package is published with types
-    import google.generativeai as genai  # type: ignore[import]
 IMAGE_TYPES: Tuple = ()
 try:
     import PIL
@@ -63,8 +61,10 @@ except ImportError:
     PIL = None  # type: ignore
     Image = None  # type: ignore
 
+logger = logging.getLogger(__name__)
+
 
-class ChatGoogleGenerativeAIError(Exception):
+class ChatGoogleGenerativeAIError(GoogleGenerativeAIError):
     """
     Custom exception class for errors associated with the `Google GenAI` API.
 
@@ -106,7 +106,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
     )
 
 
-def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.
 
@@ -139,7 +139,7 @@ def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     return _chat_with_retry(**kwargs)
 
 
-async def achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
     """
     Executes a chat generation method with retry logic using tenacity.
 
@@ -269,8 +269,6 @@ def _convert_to_parts(
     content: Sequence[Union[str, dict]],
 ) -> List[genai.types.PartType]:
     """Converts a list of LangChain messages into Google parts."""
-    import google.generativeai as genai
-
     parts = []
     for part in content:
         if isinstance(part, str):
@@ -410,8 +408,7 @@ def _response_to_result(
 class ChatGoogleGenerativeAI(BaseChatModel):
     """`Google Generative AI` Chat models API.
 
-    To use you must have the google.generativeai Python package installed and
-    either:
+    To use, you must have either:
 
     1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
     2. Pass your API key using the google_api_key kwarg to the ChatGoogleGenerativeAI
@@ -435,7 +432,7 @@ Supported examples:
     max_output_tokens: int = Field(default=None, description="Max output tokens")
 
     client: Any  #: :meta private:
-    google_api_key: Optional[str] = None
+    google_api_key: Optional[SecretStr] = None
     temperature: Optional[float] = None
     """Run inference with this temperature. Must be in the closed
        interval [0.0, 1.0]."""
@@ -487,17 +484,9 @@ Supported examples:
         google_api_key = get_from_dict_or_env(
             values, "google_api_key", "GOOGLE_API_KEY"
         )
-        try:
-            import google.generativeai as genai
-
-            genai.configure(api_key=google_api_key)
-        except ImportError:
-            raise ChatGoogleGenerativeAIError(
-                "Could not import google.generativeai python package. "
-                "Please install it with `pip install google-generativeai`"
-            )
-
-        values["client"] = genai
+        if isinstance(google_api_key, SecretStr):
+            google_api_key = google_api_key.get_secret_value()
+        genai.configure(api_key=google_api_key)
         if (
             values.get("temperature") is not None
             and not 0 <= values["temperature"] <= 1
@@ -560,7 +549,7 @@ Supported examples:
         **kwargs: Any,
     ) -> ChatResult:
         params = self._prepare_params(messages, stop, **kwargs)
-        response: genai.types.GenerateContentResponse = chat_with_retry(
+        response: genai.types.GenerateContentResponse = _chat_with_retry(
             **params,
             generation_method=self._generation_method,
         )
@@ -574,7 +563,7 @@ Supported examples:
         **kwargs: Any,
     ) -> ChatResult:
         params = self._prepare_params(messages, stop, **kwargs)
-        response: genai.types.GenerateContentResponse = await achat_with_retry(
+        response: genai.types.GenerateContentResponse = await _achat_with_retry(
             **params,
             generation_method=self._async_generation_method,
         )
@@ -588,7 +577,7 @@ Supported examples:
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         params = self._prepare_params(messages, stop, **kwargs)
-        response: genai.types.GenerateContentResponse = chat_with_retry(
+        response: genai.types.GenerateContentResponse = _chat_with_retry(
             **params,
             generation_method=self._generation_method,
             stream=True,
@@ -614,7 +603,7 @@ Supported examples:
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         params = self._prepare_params(messages, stop, **kwargs)
-        async for chunk in await achat_with_retry(
+        async for chunk in await _achat_with_retry(
             **params,
             generation_method=self._async_generation_method,
             stream=True,
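The body of `_create_retry_decorator` falls outside these hunks; judging from the tenacity imports above, it presumably composes something like the sketch below. The stop and wait values here are assumptions for illustration, not the package's actual settings.

```python
import logging
from typing import Any, Callable

from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator() -> Callable[[Any], Any]:
    # Assumed parameters: retry up to 5 times with exponential backoff,
    # logging a warning before each sleep, then re-raise the last error.
    return retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
```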
langchain_google_genai/embeddings.py (new file)

@@ -0,0 +1,99 @@
from typing import Dict, List, Optional

# TODO: remove ignore once the google package is published with types
import google.generativeai as genai  # type: ignore[import]
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_google_genai._common import GoogleGenerativeAIError


class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
    """`Google Generative AI Embeddings`.

    To use, you must have either:

        1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           GoogleGenerativeAIEmbeddings constructor.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAIEmbeddings

            embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
            embeddings.embed_query("What's our Q1 revenue?")
    """

    model: str = Field(
        ...,
        description="The name of the embedding model to use. "
        "Example: models/embedding-001",
    )
    task_type: Optional[str] = Field(
        None,
        description="The task type. Valid options include: "
        "task_type_unspecified, retrieval_query, retrieval_document, "
        "semantic_similarity, classification, and clustering",
    )
    google_api_key: Optional[SecretStr] = Field(
        None,
        description="The Google API key to use. If not provided, "
        "the GOOGLE_API_KEY environment variable will be used.",
    )

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates the API key and configures the genai client."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        if isinstance(google_api_key, SecretStr):
            google_api_key = google_api_key.get_secret_value()
        genai.configure(api_key=google_api_key)
        return values

    def _embed(
        self, texts: List[str], task_type: str, title: Optional[str] = None
    ) -> List[List[float]]:
        try:
            result = genai.embed_content(
                model=self.model,
                content=texts,
                task_type=task_type,
                title=title,
            )
        except Exception as e:
            raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e
        return result["embedding"]

    def embed_documents(
        self, texts: List[str], batch_size: int = 5
    ) -> List[List[float]]:
        """Embed a list of strings. Vertex AI currently
        sets a max batch size of 5 strings.

        Args:
            texts: List[str] The list of strings to embed.
            batch_size: [int] The batch size of embeddings to send to the model

        Returns:
            List of embeddings, one for each text.
        """
        task_type = self.task_type or "retrieval_document"
        return self._embed(texts, task_type=task_type)

    def embed_query(self, text: str) -> List[float]:
        """Embed a text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        task_type = self.task_type or "retrieval_query"
        return self._embed([text], task_type=task_type)[0]
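A short sketch of how the `task_type` defaults above play out in practice, assuming a valid `GOOGLE_API_KEY` in the environment:

```python
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Without a task_type, embed_query defaults to "retrieval_query" and
# embed_documents defaults to "retrieval_document".
default_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

# With an explicit task_type, that type is used for every method,
# embed_query included.
clustering_embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001", task_type="clustering"
)

vec_a = default_embeddings.embed_query("hello, world!")
vec_b = clustering_embeddings.embed_query("hello, world!")
print(vec_a[:3], vec_b[:3])  # generally differ, since the task types differ
```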
libs/partners/google-genai/poetry.lock (generated, 58 changed lines)

@@ -441,7 +441,6 @@ optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
 files = [
     {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
     {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
 ]
 
 [[package]]
@@ -546,6 +545,51 @@ files = [
     {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
 ]
 
+[[package]]
+name = "numpy"
+version = "1.26.2"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "numpy-1.26.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3703fc9258a4a122d17043e57b35e5ef1c5a5837c3db8be396c82e04c1cf9b0f"},
+    {file = "numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc392fdcbd21d4be6ae1bb4475a03ce3b025cd49a9be5345d76d7585aea69440"},
+    {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36340109af8da8805d8851ef1d74761b3b88e81a9bd80b290bbfed61bd2b4f75"},
+    {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc008217145b3d77abd3e4d5ef586e3bdfba8fe17940769f8aa09b99e856c00"},
+    {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ced40d4e9e18242f70dd02d739e44698df3dcb010d31f495ff00a31ef6014fe"},
+    {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b272d4cecc32c9e19911891446b72e986157e6a1809b7b56518b4f3755267523"},
+    {file = "numpy-1.26.2-cp310-cp310-win32.whl", hash = "sha256:22f8fc02fdbc829e7a8c578dd8d2e15a9074b630d4da29cda483337e300e3ee9"},
+    {file = "numpy-1.26.2-cp310-cp310-win_amd64.whl", hash = "sha256:26c9d33f8e8b846d5a65dd068c14e04018d05533b348d9eaeef6c1bd787f9919"},
+    {file = "numpy-1.26.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b96e7b9c624ef3ae2ae0e04fa9b460f6b9f17ad8b4bec6d7756510f1f6c0c841"},
+    {file = "numpy-1.26.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa18428111fb9a591d7a9cc1b48150097ba6a7e8299fb56bdf574df650e7d1f1"},
+    {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06fa1ed84aa60ea6ef9f91ba57b5ed963c3729534e6e54055fc151fad0423f0a"},
+    {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca5482c3dbdd051bcd1fce8034603d6ebfc125a7bd59f55b40d8f5d246832b"},
+    {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:854ab91a2906ef29dc3925a064fcd365c7b4da743f84b123002f6139bcb3f8a7"},
+    {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f43740ab089277d403aa07567be138fc2a89d4d9892d113b76153e0e412409f8"},
+    {file = "numpy-1.26.2-cp311-cp311-win32.whl", hash = "sha256:a2bbc29fcb1771cd7b7425f98b05307776a6baf43035d3b80c4b0f29e9545186"},
+    {file = "numpy-1.26.2-cp311-cp311-win_amd64.whl", hash = "sha256:2b3fca8a5b00184828d12b073af4d0fc5fdd94b1632c2477526f6bd7842d700d"},
+    {file = "numpy-1.26.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a4cd6ed4a339c21f1d1b0fdf13426cb3b284555c27ac2f156dfdaaa7e16bfab0"},
+    {file = "numpy-1.26.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d5244aabd6ed7f312268b9247be47343a654ebea52a60f002dc70c769048e75"},
+    {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3cdb4d9c70e6b8c0814239ead47da00934666f668426fc6e94cce869e13fd7"},
+    {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa317b2325f7aa0a9471663e6093c210cb2ae9c0ad824732b307d2c51983d5b6"},
+    {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:174a8880739c16c925799c018f3f55b8130c1f7c8e75ab0a6fa9d41cab092fd6"},
+    {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f79b231bf5c16b1f39c7f4875e1ded36abee1591e98742b05d8a0fb55d8a3eec"},
+    {file = "numpy-1.26.2-cp312-cp312-win32.whl", hash = "sha256:4a06263321dfd3598cacb252f51e521a8cb4b6df471bb12a7ee5cbab20ea9167"},
+    {file = "numpy-1.26.2-cp312-cp312-win_amd64.whl", hash = "sha256:b04f5dc6b3efdaab541f7857351aac359e6ae3c126e2edb376929bd3b7f92d7e"},
+    {file = "numpy-1.26.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4eb8df4bf8d3d90d091e0146f6c28492b0be84da3e409ebef54349f71ed271ef"},
+    {file = "numpy-1.26.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a13860fdcd95de7cf58bd6f8bc5a5ef81c0b0625eb2c9a783948847abbef2c2"},
+    {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64308ebc366a8ed63fd0bf426b6a9468060962f1a4339ab1074c228fa6ade8e3"},
+    {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818"},
+    {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d73a3abcac238250091b11caef9ad12413dab01669511779bc9b29261dd50210"},
+    {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b361d369fc7e5e1714cf827b731ca32bff8d411212fccd29ad98ad622449cc36"},
+    {file = "numpy-1.26.2-cp39-cp39-win32.whl", hash = "sha256:bd3f0091e845164a20bd5a326860c840fe2af79fa12e0469a12768a3ec578d80"},
+    {file = "numpy-1.26.2-cp39-cp39-win_amd64.whl", hash = "sha256:2beef57fb031dcc0dc8fa4fe297a742027b954949cabb52a2a376c144e5e6060"},
+    {file = "numpy-1.26.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1cc3d5029a30fb5f06704ad6b23b35e11309491c999838c31f124fee32107c79"},
+    {file = "numpy-1.26.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94cc3c222bb9fb5a12e334d0479b97bb2df446fbe622b470928f5284ffca3f8d"},
+    {file = "numpy-1.26.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe6b44fb8fcdf7eda4ef4461b97b3f63c466b27ab151bec2366db8b197387841"},
+    {file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"},
+]
+
 [[package]]
 name = "packaging"
 version = "23.2"
@@ -935,7 +979,6 @@ files = [
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
     {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
     {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
     {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
     {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
     {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -943,15 +986,8 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
     {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
     {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
     {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
     {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
     {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
     {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
     {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -968,7 +1004,6 @@ files = [
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
     {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
     {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
     {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
     {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
     {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -976,7 +1011,6 @@ files = [
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
     {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
     {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
     {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
     {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
     {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -1229,4 +1263,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "7753b9e2cb62c5b4dac124f0ff43027232c45138dbf07fdacc3c320b82367dad"
+content-hash = "ec0b5e3da951c44178eac11414611121ed2783d04b8957de8f6a189b5a6bcc2b"
libs/partners/google-genai/pyproject.toml

@@ -1,9 +1,10 @@
 [tool.poetry]
 name = "langchain-google-genai"
-version = "0.0.2"
+version = "0.0.3"
 description = "An integration package connecting Google's genai package and LangChain"
 authors = []
 readme = "README.md"
+repository = "https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
@@ -16,11 +17,12 @@ optional = true
 [tool.poetry.group.test.dependencies]
 pytest = "^7.3.0"
 freezegun = "^1.2.2"
 pytest-mock = "^3.10.0"
 syrupy = "^4.0.2"
 pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
-langchain-core = {path = "../../core", develop = true}
+langchain-core = { path = "../../core", develop = true }
+numpy = "^1.26.2"
 
 [tool.poetry.group.codespell]
 optional = true
@@ -41,7 +43,7 @@ ruff = "^0.1.5"
 
 [tool.poetry.group.typing.dependencies]
 mypy = "^0.991"
-langchain-core = {path = "../../core", develop = true}
+langchain-core = { path = "../../core", develop = true }
 types-requests = "^2.28.11.5"
 types-google-cloud-ndb = "^2.2.0.1"
 types-pillow = "^10.1.0.2"
@@ -50,7 +52,7 @@ types-pillow = "^10.1.0.2"
 optional = true
 
 [tool.poetry.group.dev.dependencies]
-langchain-core = {path = "../../core", develop = true}
+langchain-core = { path = "../../core", develop = true }
 pillow = "^10.1.0"
 types-requests = "^2.31.0.10"
 types-pillow = "^10.1.0.2"
@@ -58,19 +60,16 @@ types-google-cloud-ndb = "^2.2.0.1"
 
 [tool.ruff]
 select = [
-  "E",  # pycodestyle
-  "F",  # pyflakes
-  "I",  # isort
+    "E",    # pycodestyle
+    "F",    # pyflakes
+    "I",    # isort
 ]
 
 [tool.mypy]
 disallow_untyped_defs = "True"
 exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
 
 [tool.coverage.run]
-omit = [
-    "tests/*",
-]
+omit = ["tests/*"]
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
tests/integration_tests/test_embeddings.py (new file)

@@ -0,0 +1,98 @@
import numpy as np
import pytest

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings

_MODEL = "models/embedding-001"


@pytest.mark.parametrize(
    "query",
    [
        "Hi",
        "This is a longer query string to test the embedding functionality of the"
        " model against the pickle rick?",
    ],
)
def test_embed_query_different_lengths(query: str) -> None:
    """Test embedding queries of different lengths."""
    model = GoogleGenerativeAIEmbeddings(model=_MODEL)
    result = model.embed_query(query)
    assert len(result) == 768


@pytest.mark.parametrize(
    "query",
    [
        "Hi",
        "This is a longer query string to test the embedding functionality of the"
        " model against the pickle rick?",
    ],
)
async def test_aembed_query_different_lengths(query: str) -> None:
    """Test embedding queries of different lengths asynchronously."""
    model = GoogleGenerativeAIEmbeddings(model=_MODEL)
    result = await model.aembed_query(query)
    assert len(result) == 768


def test_embed_documents() -> None:
    """Test embedding documents."""
    model = GoogleGenerativeAIEmbeddings(
        model=_MODEL,
    )
    result = model.embed_documents(["Hello world", "Good day, world"])
    assert len(result) == 2
    assert len(result[0]) == 768
    assert len(result[1]) == 768


async def test_aembed_documents() -> None:
    """Test embedding documents asynchronously."""
    model = GoogleGenerativeAIEmbeddings(
        model=_MODEL,
    )
    result = await model.aembed_documents(["Hello world", "Good day, world"])
    assert len(result) == 2
    assert len(result[0]) == 768
    assert len(result[1]) == 768


def test_invalid_model_error_handling() -> None:
    """Test error handling with an invalid model name."""
    with pytest.raises(GoogleGenerativeAIError):
        GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query("Hello world")


def test_invalid_api_key_error_handling() -> None:
    """Test error handling with an invalid API key."""
    with pytest.raises(GoogleGenerativeAIError):
        GoogleGenerativeAIEmbeddings(
            model=_MODEL, google_api_key="invalid_key"
        ).embed_query("Hello world")


def test_embed_documents_consistency() -> None:
    """Test embedding consistency for the same document."""
    model = GoogleGenerativeAIEmbeddings(model=_MODEL)
    doc = "Consistent document for testing"
    result1 = model.embed_documents([doc])
    result2 = model.embed_documents([doc])
    assert result1 == result2


def test_embed_documents_quality() -> None:
    """Smoke test embedding quality by comparing similar and dissimilar documents."""
    model = GoogleGenerativeAIEmbeddings(model=_MODEL)
    similar_docs = ["Document A", "Similar Document A"]
    dissimilar_docs = ["Document A", "Completely Different Zebra"]
    similar_embeddings = model.embed_documents(similar_docs)
    dissimilar_embeddings = model.embed_documents(dissimilar_docs)
    similar_distance = np.linalg.norm(
        np.array(similar_embeddings[0]) - np.array(similar_embeddings[1])
    )
    dissimilar_distance = np.linalg.norm(
        np.array(dissimilar_embeddings[0]) - np.array(dissimilar_embeddings[1])
    )
    assert similar_distance < dissimilar_distance
tests/unit_tests/test_chat_models.py

@@ -1,5 +1,6 @@
 """Test chat model integration."""
 
+from langchain_core.pydantic_v1 import SecretStr
+from pytest import CaptureFixture
 
 from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
@@ -22,3 +23,16 @@ def test_integration_initialization() -> None:
         temperature=0.7,
         candidate_count=2,
     )
+
+
+def test_api_key_is_string() -> None:
+    chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
+    assert isinstance(chat.google_api_key, SecretStr)
+
+
+def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
+    chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
+    print(chat.google_api_key, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
tests/unit_tests/test_embeddings.py (new file)

@@ -0,0 +1,37 @@
"""Test embeddings model integration."""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture

from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings


def test_integration_initialization() -> None:
    """Test embeddings model initialization."""
    GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key="...",
    )
    GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key="...",
        task_type="retrieval_document",
    )


def test_api_key_is_string() -> None:
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key="secret-api-key",
    )
    assert isinstance(embeddings.google_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key="secret-api-key",
    )
    print(embeddings.google_api_key, end="")
    captured = capsys.readouterr()

    assert captured.out == "**********"
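The masking these tests assert comes from pydantic's `SecretStr`, whose string form hides the wrapped value; a tiny illustration:

```python
from langchain_core.pydantic_v1 import SecretStr

secret = SecretStr("secret-api-key")
print(secret)  # **********
# validate_environment calls get_secret_value() to recover the real key.
print(secret.get_secret_value())  # secret-api-key
```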
tests/unit_tests/test_imports.py

@@ -2,6 +2,7 @@ from langchain_google_genai import __all__
 
 EXPECTED_ALL = [
     "ChatGoogleGenerativeAI",
+    "GoogleGenerativeAIEmbeddings",
 ]
 
 