mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
[Partner] Add langchain-google-genai package (gemini) (#14621)
Add a new ChatGoogleGenerativeAI class in a `langchain-google-genai` package. Still todo: add a deprecation warning in PALM --------- Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
4574749147
commit
405d111da6
1
.github/workflows/_release.yml
vendored
1
.github/workflows/_release.yml
vendored
@ -18,6 +18,7 @@ on:
|
||||
- libs/core
|
||||
- libs/experimental
|
||||
- libs/community
|
||||
- libs/partners/google-genai
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
1
libs/partners/google-genai/.gitignore
vendored
Normal file
1
libs/partners/google-genai/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__
|
21
libs/partners/google-genai/LICENSE
Normal file
21
libs/partners/google-genai/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
61
libs/partners/google-genai/Makefile
Normal file
61
libs/partners/google-genai/Makefile
Normal file
@ -0,0 +1,61 @@
|
||||
.PHONY: all format lint test tests integration_tests help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
check_imports: $(shell find langchain_google_genai -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
integration_tests:
|
||||
poetry run pytest tests/integration_tests
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_google_genai
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
./scripts/check_pydantic.sh .
|
||||
./scripts/lint_imports.sh
|
||||
poetry run ruff .
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
58
libs/partners/google-genai/README.md
Normal file
58
libs/partners/google-genai/README.md
Normal file
@ -0,0 +1,58 @@
|
||||
# langchain-google-genai
|
||||
|
||||
This package contains the LangChain integrations for Gemini through their generative-ai SDK.
|
||||
|
||||
## Installation
|
||||
|
||||
```python
|
||||
pip install -U langchain-google-genai
|
||||
```
|
||||
|
||||
## Chat Models
|
||||
|
||||
This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
|
||||
|
||||
To use, install the requirements, and configure your environment.
|
||||
|
||||
```bash
|
||||
export GOOGLE_API_KEY=your-api-key
|
||||
```
|
||||
|
||||
Then initialize
|
||||
|
||||
```python
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
llm = ChatGoogleGenerativeAI(model="gemini-pro")
|
||||
llm.invoke("Sing a ballad of LangChain.")
|
||||
```
|
||||
|
||||
#### Multimodal inputs
|
||||
|
||||
Gemini vision model supports image inputs when providing a single chat message. Example:
|
||||
|
||||
```
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
|
||||
# example
|
||||
message = HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?",
|
||||
}, # You can optionally provide text parts
|
||||
{"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
|
||||
]
|
||||
)
|
||||
llm.invoke([message])
|
||||
```
|
||||
|
||||
The value of `image_url` can be any of the following:
|
||||
|
||||
- A public image URL
|
||||
- An accessible gcs file (e.g., "gcs://path/to/file.png")
|
||||
- A local file path
|
||||
- A base64 encoded image (e.g., ``)
|
||||
- A PIL image
|
@ -0,0 +1,3 @@
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
||||
|
||||
__all__ = ["ChatGoogleGenerativeAI"]
|
632
libs/partners/google-genai/langchain_google_genai/chat_models.py
Normal file
632
libs/partners/google-genai/langchain_google_genai/chat_models.py
Normal file
@ -0,0 +1,632 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
IMAGE_TYPES: Tuple = ()
|
||||
try:
|
||||
import PIL
|
||||
from PIL.Image import Image
|
||||
|
||||
IMAGE_TYPES = IMAGE_TYPES + (Image,)
|
||||
except ImportError:
|
||||
PIL = None # type: ignore
|
||||
Image = None # type: ignore
|
||||
|
||||
|
||||
class ChatGoogleGenerativeAIError(Exception):
|
||||
"""
|
||||
Custom exception class for errors associated with the `Google GenAI` API.
|
||||
|
||||
This exception is raised when there are specific issues related to the
|
||||
Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
|
||||
message types or roles.
|
||||
"""
|
||||
|
||||
|
||||
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||
"""
|
||||
Creates and returns a preconfigured tenacity retry decorator.
|
||||
|
||||
The retry decorator is configured to handle specific Google API exceptions
|
||||
such as ResourceExhausted and ServiceUnavailable. It uses an exponential
|
||||
backoff strategy for retries.
|
||||
|
||||
Returns:
|
||||
Callable[[Any], Any]: A retry decorator configured for handling specific
|
||||
Google API exceptions.
|
||||
"""
|
||||
import google.api_core.exceptions
|
||||
|
||||
multiplier = 2
|
||||
min_seconds = 1
|
||||
max_seconds = 60
|
||||
max_retries = 10
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(max_retries),
|
||||
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
|
||||
retry=(
|
||||
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
|
||||
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
|
||||
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes a chat generation method with retry logic using tenacity.
|
||||
|
||||
This function is a wrapper that applies a retry mechanism to a provided
|
||||
chat generation function. It is useful for handling intermittent issues
|
||||
like network errors or temporary service unavailability.
|
||||
|
||||
Args:
|
||||
generation_method (Callable): The chat generation method to be executed.
|
||||
**kwargs (Any): Additional keyword arguments to pass to the generation method.
|
||||
|
||||
Returns:
|
||||
Any: The result from the chat generation method.
|
||||
"""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
from google.api_core.exceptions import InvalidArgument # type: ignore
|
||||
|
||||
@retry_decorator
|
||||
def _chat_with_retry(**kwargs: Any) -> Any:
|
||||
try:
|
||||
return generation_method(**kwargs)
|
||||
except InvalidArgument as e:
|
||||
# Do not retry for these errors.
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
f"Invalid argument provided to Gemini: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return _chat_with_retry(**kwargs)
|
||||
|
||||
|
||||
async def achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes a chat generation method with retry logic using tenacity.
|
||||
|
||||
This function is a wrapper that applies a retry mechanism to a provided
|
||||
chat generation function. It is useful for handling intermittent issues
|
||||
like network errors or temporary service unavailability.
|
||||
|
||||
Args:
|
||||
generation_method (Callable): The chat generation method to be executed.
|
||||
**kwargs (Any): Additional keyword arguments to pass to the generation method.
|
||||
|
||||
Returns:
|
||||
Any: The result from the chat generation method.
|
||||
"""
|
||||
retry_decorator = _create_retry_decorator()
|
||||
from google.api_core.exceptions import InvalidArgument # type: ignore
|
||||
|
||||
@retry_decorator
|
||||
async def _achat_with_retry(**kwargs: Any) -> Any:
|
||||
try:
|
||||
return await generation_method(**kwargs)
|
||||
except InvalidArgument as e:
|
||||
# Do not retry for these errors.
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
f"Invalid argument provided to Gemini: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return await _achat_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _get_role(message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
if message.role not in ("user", "model"):
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
"Gemini only supports user and model roles when"
|
||||
" providing it with Chat messages."
|
||||
)
|
||||
return message.role
|
||||
elif isinstance(message, HumanMessage):
|
||||
return "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
return "model"
|
||||
else:
|
||||
# TODO: Gemini doesn't seem to have a concept of system messages yet.
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
f"Message of '{message.type}' type not supported by Gemini."
|
||||
" Please only provide it with Human or AI (user/assistant) messages."
|
||||
)
|
||||
|
||||
|
||||
def _is_openai_parts_format(part: dict) -> bool:
|
||||
return "type" in part
|
||||
|
||||
|
||||
def _is_vision_model(model: str) -> bool:
|
||||
return "vision" in model
|
||||
|
||||
|
||||
def _is_url(s: str) -> bool:
|
||||
try:
|
||||
result = urlparse(s)
|
||||
return all([result.scheme, result.netloc])
|
||||
except Exception as e:
|
||||
logger.debug(f"Unable to parse URL: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_b64(s: str) -> bool:
|
||||
return s.startswith("data:image")
|
||||
|
||||
|
||||
def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
|
||||
try:
|
||||
from google.cloud import storage # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"google-cloud-storage is required to load images from GCS."
|
||||
" Install it with `pip install google-cloud-storage`"
|
||||
)
|
||||
if PIL is None:
|
||||
raise ImportError(
|
||||
"PIL is required to load images. Please install it "
|
||||
"with `pip install pillow`"
|
||||
)
|
||||
|
||||
gcs_client = storage.Client(project=project)
|
||||
pieces = path.split("/")
|
||||
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
|
||||
if len(blobs) > 1:
|
||||
raise ValueError(f"Found more than one candidate for {path}!")
|
||||
img_bytes = blobs[0].download_as_bytes()
|
||||
return PIL.Image.open(BytesIO(img_bytes))
|
||||
|
||||
|
||||
def _url_to_pil(image_source: str) -> Image:
|
||||
if PIL is None:
|
||||
raise ImportError(
|
||||
"PIL is required to load images. Please install it "
|
||||
"with `pip install pillow`"
|
||||
)
|
||||
try:
|
||||
if isinstance(image_source, IMAGE_TYPES):
|
||||
return image_source # type: ignore[return-value]
|
||||
elif _is_url(image_source):
|
||||
if image_source.startswith("gs://"):
|
||||
return _load_image_from_gcs(image_source)
|
||||
response = requests.get(image_source)
|
||||
response.raise_for_status()
|
||||
return PIL.Image.open(BytesIO(response.content))
|
||||
elif _is_b64(image_source):
|
||||
_, encoded = image_source.split(",", 1)
|
||||
data = base64.b64decode(encoded)
|
||||
return PIL.Image.open(BytesIO(data))
|
||||
elif os.path.exists(image_source):
|
||||
return PIL.Image.open(image_source)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The provided string is not a valid URL, base64, or file path."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to process the provided image source: {e}")
|
||||
|
||||
|
||||
def _convert_to_parts(
|
||||
content: Sequence[Union[str, dict]],
|
||||
) -> List[genai.types.PartType]:
|
||||
"""Converts a list of LangChain messages into a google parts."""
|
||||
import google.generativeai as genai
|
||||
|
||||
parts = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
parts.append(genai.types.PartDict(text=part, inline_data=None))
|
||||
elif isinstance(part, Mapping):
|
||||
# OpenAI Format
|
||||
if _is_openai_parts_format(part):
|
||||
if part["type"] == "text":
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
img_url = part["image_url"]
|
||||
if isinstance(img_url, dict):
|
||||
if "url" not in img_url:
|
||||
raise ValueError(
|
||||
f"Unrecognized message image format: {img_url}"
|
||||
)
|
||||
img_url = img_url["url"]
|
||||
parts.append({"inline_data": _url_to_pil(img_url)})
|
||||
else:
|
||||
raise ValueError(f"Unrecognized message part type: {part['type']}")
|
||||
else:
|
||||
# Yolo
|
||||
logger.warning(
|
||||
"Unrecognized message part format. Assuming it's a text part."
|
||||
)
|
||||
parts.append(part)
|
||||
else:
|
||||
# TODO: Maybe some of Google's native stuff
|
||||
# would hit this branch.
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
"Gemini only supports text and inline_data parts."
|
||||
)
|
||||
return parts
|
||||
|
||||
|
||||
def _messages_to_genai_contents(
|
||||
input_messages: Sequence[BaseMessage],
|
||||
) -> List[genai.types.ContentDict]:
|
||||
"""Converts a list of messages into a Gemini API google content dicts."""
|
||||
|
||||
messages: List[genai.types.MessageDict] = []
|
||||
for i, message in enumerate(input_messages):
|
||||
role = _get_role(message)
|
||||
if isinstance(message.content, str):
|
||||
parts = [message.content]
|
||||
else:
|
||||
parts = _convert_to_parts(message.content)
|
||||
messages.append({"role": role, "parts": parts})
|
||||
if i > 0:
|
||||
# Cannot have multiple messages from the same role in a row.
|
||||
if role == messages[-2]["role"]:
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
"Cannot have multiple messages from the same role in a row."
|
||||
" Consider merging them into a single message with multiple"
|
||||
f" parts.\nReceived: {messages}"
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
|
||||
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
|
||||
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
|
||||
# Simple text response. The typical response
|
||||
return parts[0].text
|
||||
elif not parts:
|
||||
logger.warning("Gemini produced an empty response.")
|
||||
return ""
|
||||
messages = []
|
||||
for part in parts:
|
||||
if part.text is not None:
|
||||
messages.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": part.text,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# TODO: Handle inline_data if that's a thing?
|
||||
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
|
||||
return messages
|
||||
|
||||
|
||||
def _response_to_result(
|
||||
response: genai.types.GenerateContentResponse,
|
||||
ai_msg_t: Type[BaseMessage] = AIMessage,
|
||||
human_msg_t: Type[BaseMessage] = HumanMessage,
|
||||
chat_msg_t: Type[BaseMessage] = ChatMessage,
|
||||
generation_t: Type[ChatGeneration] = ChatGeneration,
|
||||
) -> ChatResult:
|
||||
"""Converts a PaLM API response into a LangChain ChatResult."""
|
||||
llm_output = {}
|
||||
if response.prompt_feedback:
|
||||
try:
|
||||
prompt_feedback = type(response.prompt_feedback).to_dict(
|
||||
response.prompt_feedback, use_integers_for_enums=False
|
||||
)
|
||||
llm_output["prompt_feedback"] = prompt_feedback
|
||||
except Exception as e:
|
||||
logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
|
||||
|
||||
generations: List[ChatGeneration] = []
|
||||
|
||||
role_map = {
|
||||
"model": ai_msg_t,
|
||||
"user": human_msg_t,
|
||||
}
|
||||
for candidate in response.candidates:
|
||||
content = candidate.content
|
||||
parts_content = _parts_to_content(content.parts)
|
||||
if content.role not in role_map:
|
||||
logger.warning(
|
||||
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
|
||||
)
|
||||
msg = chat_msg_t(content=parts_content, role=content.role)
|
||||
else:
|
||||
msg = role_map[content.role](content=parts_content)
|
||||
generation_info = {}
|
||||
if candidate.finish_reason:
|
||||
generation_info["finish_reason"] = candidate.finish_reason.name
|
||||
if candidate.safety_ratings:
|
||||
generation_info["safety_ratings"] = [
|
||||
type(rating).to_dict(rating) for rating in candidate.safety_ratings
|
||||
]
|
||||
generations.append(generation_t(message=msg, generation_info=generation_info))
|
||||
if not response.candidates:
|
||||
# Likely a "prompt feedback" violation (e.g., toxic input)
|
||||
# Raising an error would be different than how OpenAI handles it,
|
||||
# so we'll just log a warning and continue with an empty message.
|
||||
logger.warning(
|
||||
"Gemini produced an empty response. Continuing with empty message\n"
|
||||
f"Feedback: {response.prompt_feedback}"
|
||||
)
|
||||
generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
class ChatGoogleGenerativeAI(BaseChatModel):
|
||||
"""`Google Generative AI` Chat models API.
|
||||
|
||||
To use you must have the google.generativeai Python package installed and
|
||||
either:
|
||||
|
||||
1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or
|
||||
2. Pass your API key using the google_api_key kwarg to the ChatGoogle
|
||||
constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
chat = ChatGoogleGenerativeAI(model="gemini-pro")
|
||||
chat.invoke("Write me a ballad about LangChain")
|
||||
|
||||
"""
|
||||
|
||||
model: str = Field(
|
||||
...,
|
||||
description="""The name of the model to use.
|
||||
Supported examples:
|
||||
- gemini-pro""",
|
||||
)
|
||||
max_output_tokens: int = Field(default=None, description="Max output tokens")
|
||||
|
||||
client: Any #: :meta private:
|
||||
google_api_key: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
"""Run inference with this temperature. Must by in the closed
|
||||
interval [0.0, 1.0]."""
|
||||
top_k: Optional[int] = None
|
||||
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
|
||||
Must be positive."""
|
||||
top_p: Optional[int] = None
|
||||
"""The maximum cumulative probability of tokens to consider when sampling.
|
||||
|
||||
The model uses combined Top-k and nucleus sampling.
|
||||
|
||||
Tokens are sorted based on their assigned probabilities so
|
||||
that only the most likely tokens are considered. Top-k
|
||||
sampling directly limits the maximum number of tokens to
|
||||
consider, while Nucleus sampling limits number of tokens
|
||||
based on the cumulative probability.
|
||||
|
||||
Note: The default value varies by model, see the
|
||||
`Model.top_p` attribute of the `Model` returned the
|
||||
`genai.get_model` function.
|
||||
"""
|
||||
n: int = Field(default=1, alias="candidate_count")
|
||||
"""Number of chat completions to generate for each prompt. Note that the API may
|
||||
not return the full n completions if duplicates are generated."""
|
||||
|
||||
_generative_model: Any #: :meta private:
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"google_api_key": "GOOGLE_API_KEY"}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "chat-google-generative-ai"
|
||||
|
||||
@property
|
||||
def _is_geminiai(self) -> bool:
|
||||
return self.model is not None and "gemini" in self.model
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
google_api_key = get_from_dict_or_env(
|
||||
values, "google_api_key", "GOOGLE_API_KEY"
|
||||
)
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
|
||||
genai.configure(api_key=google_api_key)
|
||||
except ImportError:
|
||||
raise ChatGoogleGenerativeAIError(
|
||||
"Could not import google.generativeai python package. "
|
||||
"Please install it with `pip install google-generativeai`"
|
||||
)
|
||||
|
||||
values["client"] = genai
|
||||
if (
|
||||
values.get("temperature") is not None
|
||||
and not 0 <= values["temperature"] <= 1
|
||||
):
|
||||
raise ValueError("temperature must be in the range [0.0, 1.0]")
|
||||
|
||||
if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
|
||||
raise ValueError("top_p must be in the range [0.0, 1.0]")
|
||||
|
||||
if values.get("top_k") is not None and values["top_k"] <= 0:
|
||||
raise ValueError("top_k must be positive")
|
||||
model = values["model"]
|
||||
values["_generative_model"] = genai.GenerativeModel(model_name=model)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
}
|
||||
|
||||
@property
|
||||
def _generation_method(self) -> Callable:
|
||||
return self._generative_model.generate_content
|
||||
|
||||
@property
|
||||
def _async_generation_method(self) -> Callable:
|
||||
return self._generative_model.generate_content_async
|
||||
|
||||
def _prepare_params(
|
||||
self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
contents = _messages_to_genai_contents(messages)
|
||||
gen_config = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"candidate_count": self.n,
|
||||
"temperature": self.temperature,
|
||||
"stop_sequences": stop,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
if "generation_config" in kwargs:
|
||||
gen_config = {**gen_config, **kwargs.pop("generation_config")}
|
||||
params = {"generation_config": gen_config, "contents": contents, **kwargs}
|
||||
return params
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = chat_with_retry(
|
||||
**params,
|
||||
generation_method=self._generation_method,
|
||||
)
|
||||
return _response_to_result(response)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = await achat_with_retry(
|
||||
**params,
|
||||
generation_method=self._async_generation_method,
|
||||
)
|
||||
return _response_to_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
response: genai.types.GenerateContentResponse = chat_with_retry(
|
||||
**params,
|
||||
generation_method=self._generation_method,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
_chat_result = _response_to_result(
|
||||
chunk,
|
||||
ai_msg_t=AIMessageChunk,
|
||||
human_msg_t=HumanMessageChunk,
|
||||
chat_msg_t=ChatMessageChunk,
|
||||
generation_t=ChatGenerationChunk,
|
||||
)
|
||||
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
||||
yield gen
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(gen.text)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._prepare_params(messages, stop, **kwargs)
|
||||
async for chunk in await achat_with_retry(
|
||||
**params,
|
||||
generation_method=self._async_generation_method,
|
||||
stream=True,
|
||||
):
|
||||
_chat_result = _response_to_result(
|
||||
chunk,
|
||||
ai_msg_t=AIMessageChunk,
|
||||
human_msg_t=HumanMessageChunk,
|
||||
chat_msg_t=ChatMessageChunk,
|
||||
generation_t=ChatGenerationChunk,
|
||||
)
|
||||
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
||||
yield gen
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(gen.text)
|
1232
libs/partners/google-genai/poetry.lock
generated
Normal file
1232
libs/partners/google-genai/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
97
libs/partners/google-genai/pyproject.toml
Normal file
97
libs/partners/google-genai/pyproject.toml
Normal file
@ -0,0 +1,97 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-google-genai"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting Google's genai package and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
langchain-core = "^0.1"
|
||||
google-generativeai = "^0.3.1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
types-requests = "^2.28.11.5"
|
||||
types-google-cloud-ndb = "^2.2.0.1"
|
||||
types-pillow = "^10.1.0.2"
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = {path = "../../core", develop = true}
|
||||
pillow = "^10.1.0"
|
||||
types-requests = "^2.31.0.10"
|
||||
types-pillow = "^10.1.0.2"
|
||||
types-google-cloud-ndb = "^2.2.0.1"
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
17
libs/partners/google-genai/scripts/check_imports.py
Normal file
17
libs/partners/google-genai/scripts/check_imports.py
Normal file
@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_faillure = True
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
27
libs/partners/google-genai/scripts/check_pydantic.sh
Executable file
27
libs/partners/google-genai/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
17
libs/partners/google-genai/scripts/lint_imports.sh
Executable file
17
libs/partners/google-genai/scripts/lint_imports.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
0
libs/partners/google-genai/tests/__init__.py
Normal file
0
libs/partners/google-genai/tests/__init__.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Test ChatGoogleGenerativeAI chat model."""
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_google_genai.chat_models import (
|
||||
ChatGoogleGenerativeAI,
|
||||
ChatGoogleGenerativeAIError,
|
||||
)
|
||||
|
||||
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
|
||||
_VISION_MODEL = "gemini-pro-vision"
|
||||
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501
|
||||
|
||||
|
||||
def test_chat_google_genai_stream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
for token in llm.stream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_chat_google_genai_astream() -> None:
|
||||
"""Test streaming tokens from OpenAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
async for token in llm.astream("I'm Pickle Rick"):
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_chat_google_genai_abatch() -> None:
|
||||
"""Test streaming tokens from ChatGoogleGenerativeAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_chat_google_genai_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatGoogleGenerativeAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_chat_google_genai_batch() -> None:
|
||||
"""Test batch tokens from ChatGoogleGenerativeAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_chat_google_genai_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatGoogleGenerativeAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_chat_google_genai_invoke() -> None:
|
||||
"""Test invoke tokens from ChatGoogleGenerativeAI."""
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
|
||||
result = llm.invoke(
|
||||
"I'm Pickle Rick",
|
||||
config=dict(tags=["foo"]),
|
||||
generation_config=dict(top_k=2, top_p=1, temperature=0.7),
|
||||
)
|
||||
assert isinstance(result.content, str)
|
||||
assert not result.content.startswith(" ")
|
||||
|
||||
|
||||
def test_chat_google_genai_invoke_multimodal() -> None:
|
||||
messages: list = [
|
||||
HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Guess what's in this picture! You have 3 guesses.",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/png;base64," + _B64_string,
|
||||
},
|
||||
]
|
||||
),
|
||||
]
|
||||
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
|
||||
response = llm.invoke(messages)
|
||||
assert isinstance(response.content, str)
|
||||
assert len(response.content.strip()) > 0
|
||||
|
||||
# Try streaming
|
||||
for chunk in llm.stream(messages):
|
||||
print(chunk)
|
||||
assert isinstance(chunk.content, str)
|
||||
assert len(chunk.content.strip()) > 0
|
||||
|
||||
|
||||
def test_chat_google_genai_invoke_multimodal_too_many_messages() -> None:
|
||||
# Only supports 1 turn...
|
||||
messages: list = [
|
||||
HumanMessage(content="Hi there"),
|
||||
AIMessage(content="Hi, how are you?"),
|
||||
HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I'm doing great! Guess what's in this picture!",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/png;base64," + _B64_string,
|
||||
},
|
||||
]
|
||||
),
|
||||
]
|
||||
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
|
||||
with pytest.raises(ChatGoogleGenerativeAIError):
|
||||
llm.invoke(messages)
|
||||
|
||||
|
||||
def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
|
||||
# need the vision model to support this.
|
||||
messages: list = [
|
||||
HumanMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I'm doing great! Guess what's in this picture!",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/png;base64," + _B64_string,
|
||||
},
|
||||
]
|
||||
),
|
||||
]
|
||||
llm = ChatGoogleGenerativeAI(model=_MODEL)
|
||||
with pytest.raises(ChatGoogleGenerativeAIError):
|
||||
llm.invoke(messages)
|
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -0,0 +1,24 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
|
||||
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
||||
|
||||
|
||||
def test_integration_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatGoogleGenerativeAI(
|
||||
model="gemini-nano",
|
||||
google_api_key="...",
|
||||
top_k=2,
|
||||
top_p=1,
|
||||
temperature=0.7,
|
||||
n=2,
|
||||
)
|
||||
ChatGoogleGenerativeAI(
|
||||
model="gemini-nano",
|
||||
google_api_key="...",
|
||||
top_k=2,
|
||||
top_p=1,
|
||||
temperature=0.7,
|
||||
candidate_count=2,
|
||||
)
|
@ -0,0 +1,9 @@
|
||||
from langchain_google_genai import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatGoogleGenerativeAI",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
Loading…
Reference in New Issue
Block a user