[Partner] Add langchain-google-genai package (gemini) (#14621)

Add a new ChatGoogleGenerativeAI class in a `langchain-google-genai`
package.
Still to do: add a deprecation warning in PaLM.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Bagatur <baskaryan@gmail.com>
William FH 2023-12-13 11:57:59 -08:00 committed by GitHub
parent 4574749147
commit 405d111da6
19 changed files with 2356 additions and 0 deletions

@@ -18,6 +18,7 @@ on:
- libs/core
- libs/experimental
- libs/community
- libs/partners/google-genai
env:
PYTHON_VERSION: "3.10"

1 libs/partners/google-genai/.gitignore vendored Normal file
@@ -0,0 +1 @@
__pycache__

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -0,0 +1,61 @@
.PHONY: all format lint test tests integration_tests help
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
check_imports: $(shell find langchain_google_genai -name '*.py')
poetry run python ./scripts/check_imports.py $^
integration_tests:
poetry run pytest tests/integration_tests
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_genai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
######################
# HELP
######################
help:
@echo '----'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@@ -0,0 +1,58 @@
# langchain-google-genai
This package contains the LangChain integrations for Gemini through Google's `google-generativeai` SDK.
## Installation
```bash
pip install -U langchain-google-genai
```
## Chat Models
This package contains the `ChatGoogleGenerativeAI` class, which is the recommended way to interface with the Google Gemini series of models.
To use, install the package and configure your environment:
```bash
export GOOGLE_API_KEY=your-api-key
```
Then initialize the chat model:
```python
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-pro")
llm.invoke("Sing a ballad of LangChain.")
```
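Streaming is also implemented; a short sketch of token streaming (mirroring the integration tests added in this change):
```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
# Each chunk is an AIMessageChunk; tokens are printed as they arrive.
for chunk in llm.stream("Sing a ballad of LangChain."):
    print(chunk.content, end="", flush=True)
```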
#### Multimodal inputs
The Gemini vision model supports image inputs when they are provided in a single chat message. Example:
```python
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")
# example
message = HumanMessage(
content=[
{
"type": "text",
"text": "What's in this image?",
}, # You can optionally provide text parts
{"type": "image_url", "image_url": "https://picsum.photos/seed/picsum/200/300"},
]
)
llm.invoke([message])
```
The value of `image_url` can be any of the following:
- A public image URL
- An accessible Google Cloud Storage file (e.g., "gs://path/to/file.png")
- A local file path
- A base64-encoded image (e.g., `data:image/png;base64,abcd124`), as in the sketch below
- A PIL image
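For instance, a local image can be sent as a base64 data URL. A minimal sketch, where the file path and prompt are purely illustrative:
```python
import base64

from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro-vision")

# Read a local image and wrap it as a data URL (the path is hypothetical).
with open("photo.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

message = HumanMessage(
    content=[
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": f"data:image/png;base64,{image_b64}"},
    ]
)
print(llm.invoke([message]).content)
```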

@@ -0,0 +1,3 @@
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
__all__ = ["ChatGoogleGenerativeAI"]

@@ -0,0 +1,632 @@
from __future__ import annotations
import base64
import logging
import os
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from urllib.parse import urlparse
import requests
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# TODO: remove ignore once the google package is published with types
import google.generativeai as genai # type: ignore[import]
IMAGE_TYPES: Tuple = ()
try:
import PIL
from PIL.Image import Image
IMAGE_TYPES = IMAGE_TYPES + (Image,)
except ImportError:
PIL = None # type: ignore
Image = None # type: ignore
class ChatGoogleGenerativeAIError(Exception):
"""
Custom exception class for errors associated with the `Google GenAI` API.
This exception is raised when there are specific issues related to the
Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
message types or roles.
"""
def _create_retry_decorator() -> Callable[[Any], Any]:
"""
Creates and returns a preconfigured tenacity retry decorator.
The retry decorator is configured to handle specific Google API exceptions
such as ResourceExhausted and ServiceUnavailable. It uses an exponential
backoff strategy for retries.
Returns:
Callable[[Any], Any]: A retry decorator configured for handling specific
Google API exceptions.
"""
import google.api_core.exceptions
multiplier = 2
min_seconds = 1
max_seconds = 60
max_retries = 10
return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
| retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
| retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
from google.api_core.exceptions import InvalidArgument # type: ignore
@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
try:
return generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return _chat_with_retry(**kwargs)
async def achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
This function is a wrapper that applies a retry mechanism to a provided
chat generation function. It is useful for handling intermittent issues
like network errors or temporary service unavailability.
Args:
generation_method (Callable): The chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
from google.api_core.exceptions import InvalidArgument # type: ignore
@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
try:
return await generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e
return await _achat_with_retry(**kwargs)
def _get_role(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
if message.role not in ("user", "model"):
raise ChatGoogleGenerativeAIError(
"Gemini only supports user and model roles when"
" providing it with Chat messages."
)
return message.role
elif isinstance(message, HumanMessage):
return "user"
elif isinstance(message, AIMessage):
return "model"
else:
# TODO: Gemini doesn't seem to have a concept of system messages yet.
raise ChatGoogleGenerativeAIError(
f"Message of '{message.type}' type not supported by Gemini."
" Please only provide it with Human or AI (user/assistant) messages."
)
def _is_openai_parts_format(part: dict) -> bool:
return "type" in part
def _is_vision_model(model: str) -> bool:
return "vision" in model
def _is_url(s: str) -> bool:
try:
result = urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _is_b64(s: str) -> bool:
return s.startswith("data:image")
def _load_image_from_gcs(path: str, project: Optional[str] = None) -> Image:
try:
from google.cloud import storage # type: ignore[attr-defined]
except ImportError:
raise ImportError(
"google-cloud-storage is required to load images from GCS."
" Install it with `pip install google-cloud-storage`"
)
if PIL is None:
raise ImportError(
"PIL is required to load images. Please install it "
"with `pip install pillow`"
)
gcs_client = storage.Client(project=project)
pieces = path.split("/")
blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
if len(blobs) > 1:
raise ValueError(f"Found more than one candidate for {path}!")
img_bytes = blobs[0].download_as_bytes()
return PIL.Image.open(BytesIO(img_bytes))
def _url_to_pil(image_source: str) -> Image:
if PIL is None:
raise ImportError(
"PIL is required to load images. Please install it "
"with `pip install pillow`"
)
try:
if isinstance(image_source, IMAGE_TYPES):
return image_source # type: ignore[return-value]
elif _is_url(image_source):
if image_source.startswith("gs://"):
return _load_image_from_gcs(image_source)
response = requests.get(image_source)
response.raise_for_status()
return PIL.Image.open(BytesIO(response.content))
elif _is_b64(image_source):
_, encoded = image_source.split(",", 1)
data = base64.b64decode(encoded)
return PIL.Image.open(BytesIO(data))
elif os.path.exists(image_source):
return PIL.Image.open(image_source)
else:
raise ValueError(
"The provided string is not a valid URL, base64, or file path."
)
except Exception as e:
raise ValueError(f"Unable to process the provided image source: {e}")
def _convert_to_parts(
content: Sequence[Union[str, dict]],
) -> List[genai.types.PartType]:
"""Converts a list of LangChain messages into a google parts."""
import google.generativeai as genai
parts = []
for part in content:
if isinstance(part, str):
parts.append(genai.types.PartDict(text=part, inline_data=None))
elif isinstance(part, Mapping):
# OpenAI Format
if _is_openai_parts_format(part):
if part["type"] == "text":
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
img_url = part["image_url"]
if isinstance(img_url, dict):
if "url" not in img_url:
raise ValueError(
f"Unrecognized message image format: {img_url}"
)
img_url = img_url["url"]
parts.append({"inline_data": _url_to_pil(img_url)})
else:
raise ValueError(f"Unrecognized message part type: {part['type']}")
else:
# Unrecognized dict shape: warn and pass the part through unchanged.
logger.warning(
"Unrecognized message part format. Assuming it's a text part."
)
parts.append(part)
else:
# TODO: Maybe some of Google's native stuff
# would hit this branch.
raise ChatGoogleGenerativeAIError(
"Gemini only supports text and inline_data parts."
)
return parts
def _messages_to_genai_contents(
input_messages: Sequence[BaseMessage],
) -> List[genai.types.ContentDict]:
"""Converts a list of messages into a Gemini API google content dicts."""
messages: List[genai.types.MessageDict] = []
for i, message in enumerate(input_messages):
role = _get_role(message)
if isinstance(message.content, str):
parts = [message.content]
else:
parts = _convert_to_parts(message.content)
messages.append({"role": role, "parts": parts})
if i > 0:
# Cannot have multiple messages from the same role in a row.
if role == messages[-2]["role"]:
raise ChatGoogleGenerativeAIError(
"Cannot have multiple messages from the same role in a row."
" Consider merging them into a single message with multiple"
f" parts.\nReceived: {messages}"
)
return messages
def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
# Simple text response; the typical case.
return parts[0].text
elif not parts:
logger.warning("Gemini produced an empty response.")
return ""
messages = []
for part in parts:
if part.text is not None:
messages.append(
{
"type": "text",
"text": part.text,
}
)
else:
# TODO: Handle inline_data if that's a thing?
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
return messages
def _response_to_result(
response: genai.types.GenerateContentResponse,
ai_msg_t: Type[BaseMessage] = AIMessage,
human_msg_t: Type[BaseMessage] = HumanMessage,
chat_msg_t: Type[BaseMessage] = ChatMessage,
generation_t: Type[ChatGeneration] = ChatGeneration,
) -> ChatResult:
"""Converts a PaLM API response into a LangChain ChatResult."""
llm_output = {}
if response.prompt_feedback:
try:
prompt_feedback = type(response.prompt_feedback).to_dict(
response.prompt_feedback, use_integers_for_enums=False
)
llm_output["prompt_feedback"] = prompt_feedback
except Exception as e:
logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
generations: List[ChatGeneration] = []
role_map = {
"model": ai_msg_t,
"user": human_msg_t,
}
for candidate in response.candidates:
content = candidate.content
parts_content = _parts_to_content(content.parts)
if content.role not in role_map:
logger.warning(
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
)
msg = chat_msg_t(content=parts_content, role=content.role)
else:
msg = role_map[content.role](content=parts_content)
generation_info = {}
if candidate.finish_reason:
generation_info["finish_reason"] = candidate.finish_reason.name
if candidate.safety_ratings:
generation_info["safety_ratings"] = [
type(rating).to_dict(rating) for rating in candidate.safety_ratings
]
generations.append(generation_t(message=msg, generation_info=generation_info))
if not response.candidates:
# Likely a "prompt feedback" violation (e.g., toxic input)
# Raising an error would be different than how OpenAI handles it,
# so we'll just log a warning and continue with an empty message.
logger.warning(
"Gemini produced an empty response. Continuing with empty message\n"
f"Feedback: {response.prompt_feedback}"
)
generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
return ChatResult(generations=generations, llm_output=llm_output)
class ChatGoogleGenerativeAI(BaseChatModel):
"""`Google Generative AI` Chat models API.
To use, you must have the google.generativeai Python package installed and
either:
1. The ``GOOGLE_API_KEY`` environment variable set with your API key, or
2. Pass your API key using the ``google_api_key`` kwarg to the
ChatGoogleGenerativeAI constructor.
Example:
.. code-block:: python
from langchain_google_genai import ChatGoogleGenerativeAI
chat = ChatGoogleGenerativeAI(model="gemini-pro")
chat.invoke("Write me a ballad about LangChain")
"""
model: str = Field(
...,
description="""The name of the model to use.
Supported examples:
- gemini-pro""",
)
max_output_tokens: Optional[int] = Field(default=None, description="Max output tokens")
client: Any #: :meta private:
google_api_key: Optional[str] = None
temperature: Optional[float] = None
"""Run inference with this temperature. Must by in the closed
interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
top_p: Optional[float] = None
"""The maximum cumulative probability of tokens to consider when sampling.
The model uses combined top-k and nucleus sampling.
Tokens are sorted based on their assigned probabilities so
that only the most likely tokens are considered. Top-k
sampling directly limits the maximum number of tokens to
consider, while nucleus sampling limits the number of tokens
based on the cumulative probability.
Note: The default value varies by model; see the
`Model.top_p` attribute of the `Model` returned by the
`genai.get_model` function.
"""
n: int = Field(default=1, alias="candidate_count")
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
_generative_model: Any #: :meta private:
class Config:
allow_population_by_field_name = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@property
def _llm_type(self) -> str:
return "chat-google-generative-ai"
@property
def _is_geminiai(self) -> bool:
return self.model is not None and "gemini" in self.model
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
google_api_key = get_from_dict_or_env(
values, "google_api_key", "GOOGLE_API_KEY"
)
try:
import google.generativeai as genai
genai.configure(api_key=google_api_key)
except ImportError:
raise ChatGoogleGenerativeAIError(
"Could not import google.generativeai python package. "
"Please install it with `pip install google-generativeai`"
)
values["client"] = genai
if (
values.get("temperature") is not None
and not 0 <= values["temperature"] <= 1
):
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values.get("top_p") is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
if values.get("top_k") is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
model = values["model"]
values["_generative_model"] = genai.GenerativeModel(model_name=model)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"model": self.model,
"temperature": self.temperature,
"top_k": self.top_k,
"n": self.n,
}
@property
def _generation_method(self) -> Callable:
return self._generative_model.generate_content
@property
def _async_generation_method(self) -> Callable:
return self._generative_model.generate_content_async
def _prepare_params(
self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
contents = _messages_to_genai_contents(messages)
gen_config = {
k: v
for k, v in {
"candidate_count": self.n,
"temperature": self.temperature,
"stop_sequences": stop,
"max_output_tokens": self.max_output_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
}.items()
if v is not None
}
if "generation_config" in kwargs:
gen_config = {**gen_config, **kwargs.pop("generation_config")}
params = {"generation_config": gen_config, "contents": contents, **kwargs}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
response: genai.types.GenerateContentResponse = chat_with_retry(
**params,
generation_method=self._generation_method,
)
return _response_to_result(response)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
response: genai.types.GenerateContentResponse = await achat_with_retry(
**params,
generation_method=self._async_generation_method,
)
return _response_to_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
response: genai.types.GenerateContentResponse = chat_with_retry(
**params,
generation_method=self._generation_method,
stream=True,
)
for chunk in response:
_chat_result = _response_to_result(
chunk,
ai_msg_t=AIMessageChunk,
human_msg_t=HumanMessageChunk,
chat_msg_t=ChatMessageChunk,
generation_t=ChatGenerationChunk,
)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
yield gen
if run_manager:
run_manager.on_llm_new_token(gen.text)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
async for chunk in await achat_with_retry(
**params,
generation_method=self._async_generation_method,
stream=True,
):
_chat_result = _response_to_result(
chunk,
ai_msg_t=AIMessageChunk,
human_msg_t=HumanMessageChunk,
chat_msg_t=ChatMessageChunk,
generation_t=ChatGenerationChunk,
)
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
yield gen
if run_manager:
await run_manager.on_llm_new_token(gen.text)

1232 libs/partners/google-genai/poetry.lock generated Normal file

File diff suppressed because it is too large

@@ -0,0 +1,97 @@
[tool.poetry]
name = "langchain-google-genai"
version = "0.0.1"
description = "An integration package connecting Google's genai package and LangChain"
authors = []
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-core = "^0.1"
google-generativeai = "^0.3.1"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = {path = "../../core", develop = true}
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = {path = "../../core", develop = true}
types-requests = "^2.28.11.5"
types-google-cloud-ndb = "^2.2.0.1"
types-pillow = "^10.1.0.2"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}
pillow = "^10.1.0"
types-requests = "^2.31.0.10"
types-pillow = "^10.1.0.2"
types-google-cloud-ndb = "^2.2.0.1"
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]
[tool.coverage.run]
omit = [
"tests/*",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@@ -0,0 +1,149 @@
"""Test ChatGoogleGenerativeAI chat model."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
ChatGoogleGenerativeAIError,
)
_MODEL = "gemini-pro" # TODO: Use nano when it's available.
_VISION_MODEL = "gemini-pro-vision"
_B64_string = """iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABhGlDQ1BJQ0MgUHJvZmlsZQAAeJx9kT1Iw0AcxV8/xCIVQTuIKGSoTi2IijhqFYpQIdQKrTqYXPoFTRqSFBdHwbXg4Mdi1cHFWVcHV0EQ/ABxdXFSdJES/5cUWsR4cNyPd/ced+8Af6PCVDM4DqiaZaSTCSGbWxW6XxHECPoRQ0hipj4niil4jq97+Ph6F+dZ3uf+HL1K3mSATyCeZbphEW8QT29aOud94ggrSQrxOXHMoAsSP3JddvmNc9FhP8+MGJn0PHGEWCh2sNzBrGSoxFPEUUXVKN+fdVnhvMVZrdRY6578heG8trLMdZrDSGIRSxAhQEYNZVRgIU6rRoqJNO0nPPxDjl8kl0yuMhg5FlCFCsnxg//B727NwuSEmxROAF0vtv0xCnTvAs26bX8f23bzBAg8A1da219tADOfpNfbWvQI6NsGLq7bmrwHXO4Ag0+6ZEiOFKDpLxSA9zP6phwwcAv0rLm9tfZx+gBkqKvUDXBwCIwVKXvd492hzt7+PdPq7wdzbXKn5swsVgAAA8lJREFUeJx90dtPHHUUB/Dz+81vZhb2wrDI3soUKBSRcisF21iqqCRNY01NTE0k8aHpi0k18VJfjOFvUF9M44MmGrHFQqSQiKSmFloL5c4CXW6Fhb0vO3ufvczMzweiBGI9+eW8ffI95/yQqqrwv4UxBgCfJ9w/2NfSVB+Nyn6/r+vdLo7H6FkYY6yoABR2PJujj34MSo/d/nHeVLYbydmIp/bEO0fEy/+NMcbTU4/j4Vs6Lr0ccKeYuUKWS4ABVCVHmRdszbfvTgfjR8kz5Jjs+9RREl9Zy2lbVK9wU3/kWLJLCXnqza1bfVe7b9jLbIeTMcYu13Jg/aMiPrCwVFcgtDiMhnxwJ/zXVDwSdVCVMRV7nqzl2i9e/fKrw8mqSp84e2sFj3Oj8/SrF/MaicmyYhAaXu58NPAbeAeyzY0NLecmh2+ODN3BewYBAkAY43giI3kebrnsRmvV9z2D4ciOa3EBAf31Tp9sMgdxMTFm6j74/Ogb70VCYQKAAIDCXkOAIC6pkYBWdwwnpHEdf6L9dJtJKPh95DZhzFKMEWRAGL927XpWTmMA+s8DAOBYAoR483l/iHZ/8bXoODl8b9UfyH72SXepzbyRJNvjFGHKMlhvMBze+cH9+4lEuOOlU2X1tVkFTU7Om03q080NDGXV1cflRpHwaaoiiiildB8jhDLZ7HDfz2Yidba6Vn2L4fhzFrNRKy5OZ2QOZ1U5W8VtqlVH/iUHcM933zZYWS7Wtj66zZr65bzGJQt0glHgudi9XVzEl4vKw2kUPhO020oPYI1qYc+2Xc0bRXFwTLY0VXa2VibD/lBaIXm1UChN5JSRUcQQ1Tk/47Cf3x8bY7y17Y17PVYTG1UkLPBFcqik7Zoa9JcLYoHBqHhXNgd6gS1k9EJ1TQ2l9EDy1saErmQ2kGpwGC2MLOtCM8nZEV1K0tKJtEksSm26J/rHg2zzmabKisq939nHzqUH7efzd4f/nPGW6NP8ybNFrOsWQhpoCuuhnJ4hAnPhFam01K4oQMjBg/mzBjVhuvw2O++KKT+BIVxJKzQECBDLF2qu2WTMmCovtDQ1f8iyoGkUADBCCGPsdnvTW2OtFm01VeB06msvdWlpPZU0wJRG85ns84umU3k+VyxeEcWqvYUBAGsUrbvme4be99HFeisP/pwUOIZaOqQX31ISgrKmZhLHtXNXuJq68orrr5/9mBCglCLAGGPyy81votEbcjlKLrC9E8mhH3wdHRdcyyvjidSlxjftPJpD+o25JYvRHGFoZDdks1mBQhxJu9uxvwEiXuHnHbLd1AAAAABJRU5ErkJggg==""" # noqa: E501
def test_chat_google_genai_stream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_chat_google_genai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_chat_google_genai_abatch() -> None:
"""Test streaming tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_chat_google_genai_abatch_tags() -> None:
"""Test batch tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_chat_google_genai_batch() -> None:
"""Test batch tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_chat_google_genai_ainvoke() -> None:
"""Test invoke tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_chat_google_genai_invoke() -> None:
"""Test invoke tokens from ChatGoogleGenerativeAI."""
llm = ChatGoogleGenerativeAI(model=_MODEL)
result = llm.invoke(
"I'm Pickle Rick",
config=dict(tags=["foo"]),
generation_config=dict(top_k=2, top_p=1, temperature=0.7),
)
assert isinstance(result.content, str)
assert not result.content.startswith(" ")
def test_chat_google_genai_invoke_multimodal() -> None:
messages: list = [
HumanMessage(
content=[
{
"type": "text",
"text": "Guess what's in this picture! You have 3 guesses.",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
response = llm.invoke(messages)
assert isinstance(response.content, str)
assert len(response.content.strip()) > 0
# Try streaming
for chunk in llm.stream(messages):
print(chunk)
assert isinstance(chunk.content, str)
assert len(chunk.content.strip()) > 0
def test_chat_google_genai_invoke_multimodal_too_many_messages() -> None:
# Only supports 1 turn...
messages: list = [
HumanMessage(content="Hi there"),
AIMessage(content="Hi, how are you?"),
HumanMessage(
content=[
{
"type": "text",
"text": "I'm doing great! Guess what's in this picture!",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_VISION_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)
def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
# need the vision model to support this.
messages: list = [
HumanMessage(
content=[
{
"type": "text",
"text": "I'm doing great! Guess what's in this picture!",
},
{
"type": "image_url",
"image_url": "data:image/png;base64," + _B64_string,
},
]
),
]
llm = ChatGoogleGenerativeAI(model=_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)

@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@@ -0,0 +1,24 @@
"""Test chat model integration."""
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
def test_integration_initialization() -> None:
"""Test chat model initialization."""
ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
top_k=2,
top_p=1,
temperature=0.7,
n=2,
)
ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
top_k=2,
top_p=1,
temperature=0.7,
candidate_count=2,
)

@@ -0,0 +1,9 @@
from langchain_google_genai import __all__
EXPECTED_ALL = [
"ChatGoogleGenerativeAI",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)