google-genai[patch], community[patch]: Added support for new Google GenerativeAI models (#14530)

- **Description:** Added support for new Google GenerativeAI (Gemini) models.
- **Twitter handle:** lkuligin
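
For reviewers, a minimal usage sketch of the new entry point (assumes `langchain-google-genai` is installed and a valid `GOOGLE_API_KEY` is exported; model names follow the tests in this PR):

```python
from langchain_google_genai import GoogleGenerativeAI

# Works for both the new Gemini family ("gemini-pro") and the older
# PaLM text models ("models/text-bison-001").
llm = GoogleGenerativeAI(model="gemini-pro")
print(llm.invoke("Sing a ballad of LangChain."))
```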

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Leonid Kuligin authored on 2023-12-15 05:56:46 +01:00, committed by GitHub
parent 6bbf0797f7 · commit 7f42811e14
7 changed files with 791 additions and 42 deletions

View File

@@ -1,10 +1,12 @@
 from __future__ import annotations

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterator, List, Optional

+from langchain_core._api.deprecation import deprecated
 from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.outputs import Generation, LLMResult
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.language_models import LanguageModelInput
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env

 from langchain_community.llms import BaseLLM
@@ -13,7 +15,9 @@ from langchain_community.utilities.vertexai import create_retry_decorator

 def completion_with_retry(
     llm: GooglePalm,
-    *args: Any,
+    prompt: LanguageModelInput,
+    is_gemini: bool = False,
+    stream: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -23,10 +27,23 @@ def completion_with_retry(
     )

     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.generate_text(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
+    ) -> Any:
+        generation_config = kwargs.get("generation_config", {})
+        if is_gemini:
+            return llm.client.generate_content(
+                contents=prompt, stream=stream, generation_config=generation_config
+            )
+        return llm.client.generate_text(prompt=prompt, **kwargs)

-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(
+        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
+    )
+
+
+def _is_gemini_model(model_name: str) -> bool:
+    return "gemini" in model_name


 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -42,11 +59,16 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
         return text


+@deprecated("0.0.351", alternative="langchain_google_genai.GoogleGenerativeAI")
 class GooglePalm(BaseLLM, BaseModel):
-    """Google PaLM models."""
+    """
+    DEPRECATED: Use `langchain_google_genai.GoogleGenerativeAI` instead.
+
+    Google PaLM models.
+    """

     client: Any  #: :meta private:
-    google_api_key: Optional[str]
+    google_api_key: Optional[SecretStr]
     model_name: str = "models/text-bison-001"
     """Model name to use."""

     temperature: float = 0.7
@@ -67,6 +89,11 @@ class GooglePalm(BaseLLM, BaseModel):
     max_retries: int = 6
     """The maximum number of retries to make when generating."""

+    @property
+    def is_gemini(self) -> bool:
+        """Returns whether the model belongs to the Gemini family."""
+        return _is_gemini_model(self.model_name)
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}
@@ -86,18 +113,25 @@ class GooglePalm(BaseLLM, BaseModel):
         google_api_key = get_from_dict_or_env(
             values, "google_api_key", "GOOGLE_API_KEY"
         )
+        model_name = values["model_name"]
         try:
             import google.generativeai as genai

+            if isinstance(google_api_key, SecretStr):
+                google_api_key = google_api_key.get_secret_value()
             genai.configure(api_key=google_api_key)
+
+            if _is_gemini_model(model_name):
+                values["client"] = genai.GenerativeModel(model_name=model_name)
+            else:
+                values["client"] = genai
         except ImportError:
             raise ImportError(
                 "Could not import google-generativeai python package. "
                 "Please install it with `pip install google-generativeai`."
             )
-        values["client"] = genai

         if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
             raise ValueError("temperature must be in the range [0.0, 1.0]")
@@ -119,30 +153,76 @@ class GooglePalm(BaseLLM, BaseModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
-        generations = []
+        generations: List[List[Generation]] = []
+        generation_config = {
+            "stop_sequences": stop,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "max_output_tokens": self.max_output_tokens,
+            "candidate_count": self.n,
+        }
         for prompt in prompts:
-            completion = completion_with_retry(
-                self,
-                model=self.model_name,
-                prompt=prompt,
-                stop_sequences=stop,
-                temperature=self.temperature,
-                top_p=self.top_p,
-                top_k=self.top_k,
-                max_output_tokens=self.max_output_tokens,
-                candidate_count=self.n,
-                **kwargs,
-            )
-            prompt_generations = []
-            for candidate in completion.candidates:
-                raw_text = candidate["output"]
-                stripped_text = _strip_erroneous_leading_spaces(raw_text)
-                prompt_generations.append(Generation(text=stripped_text))
-            generations.append(prompt_generations)
+            if self.is_gemini:
+                res = completion_with_retry(
+                    self,
+                    prompt=prompt,
+                    stream=False,
+                    is_gemini=True,
+                    run_manager=run_manager,
+                    generation_config=generation_config,
+                )
+                candidates = [
+                    "".join([p.text for p in c.content.parts]) for c in res.candidates
+                ]
+                generations.append([Generation(text=c) for c in candidates])
+            else:
+                res = completion_with_retry(
+                    self,
+                    model=self.model_name,
+                    prompt=prompt,
+                    stream=False,
+                    is_gemini=False,
+                    run_manager=run_manager,
+                    **generation_config,
+                )
+                prompt_generations = []
+                for candidate in res.candidates:
+                    raw_text = candidate["output"]
+                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
+                    prompt_generations.append(Generation(text=stripped_text))
+                generations.append(prompt_generations)
         return LLMResult(generations=generations)

+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generation_config = kwargs.get("generation_config", {})
+        if stop:
+            generation_config["stop_sequences"] = stop
+        for stream_resp in completion_with_retry(
+            self,
+            prompt,
+            stream=True,
+            is_gemini=True,
+            run_manager=run_manager,
+            generation_config=generation_config,
+            **kwargs,
+        ):
+            chunk = GenerationChunk(text=stream_resp.text)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    stream_resp.text,
+                    chunk=chunk,
+                    verbose=self.verbose,
+                )
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -159,5 +239,7 @@ class GooglePalm(BaseLLM, BaseModel):
         Returns:
             The integer number of tokens in the text.
         """
+        if self.is_gemini:
+            raise ValueError("Counting tokens is not yet supported!")
         result = self.client.count_text_tokens(model=self.model_name, prompt=text)
         return result["token_count"]
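
Taken together, these community changes keep the deprecated `GooglePalm` class working for both model families: Gemini prompts are routed to `GenerativeModel.generate_content` (with a `generation_config` dict), PaLM prompts to `genai.generate_text`. A minimal sketch of the user-visible behavior, assuming a valid `GOOGLE_API_KEY` (prompts mirror the integration tests below):

```python
from langchain_community.llms.google_palm import GooglePalm

# PaLM text model: dispatched to genai.generate_text under the hood.
palm = GooglePalm(model_name="models/text-bison-001", max_output_tokens=10)
print(palm("Say foo:"))

# Gemini model: dispatched to GenerativeModel.generate_content; the new
# _stream method enables token streaming for Gemini models only.
gemini = GooglePalm(model_name="gemini-pro", temperature=0)
for token in gemini.stream("Please say foo:"):
    print(token, end="")
```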

View File

@@ -1,4 +1,4 @@
-"""Test Google PaLM Text API wrapper.
+"""Test Google GenerativeAI API wrapper.

 Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
 valid API key.
@@ -6,35 +6,68 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 from pathlib import Path

+import pytest
 from langchain_core.outputs import LLMResult

 from langchain_community.llms.google_palm import GooglePalm
 from langchain_community.llms.loading import load_llm

+model_names = [None, "models/text-bison-001", "gemini-pro"]

-def test_google_palm_call() -> None:
-    """Test valid call to Google PaLM text API."""
-    llm = GooglePalm(max_output_tokens=10)
+
+@pytest.mark.parametrize(
+    "model_name",
+    model_names,
+)
+def test_google_generativeai_call(model_name: str) -> None:
+    """Test valid call to Google GenerativeAI text API."""
+    if model_name:
+        llm = GooglePalm(max_output_tokens=10, model_name=model_name)
+    else:
+        llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
     assert llm._llm_type == "google_palm"
-    assert llm.model_name == "models/text-bison-001"
+    if model_name and "gemini" in model_name:
+        assert llm.client.model_name == "models/gemini-pro"
+    else:
+        assert llm.model_name == "models/text-bison-001"


-def test_google_palm_generate() -> None:
-    llm = GooglePalm(temperature=0.3, n=2)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names,
+)
+def test_google_generativeai_generate(model_name: str) -> None:
+    n = 1 if model_name == "gemini-pro" else 2
+    if model_name:
+        llm = GooglePalm(temperature=0.3, n=n, model_name=model_name)
+    else:
+        llm = GooglePalm(temperature=0.3, n=n)
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 1
-    assert len(output.generations[0]) == 2
+    assert len(output.generations[0]) == n


-def test_google_palm_get_num_tokens() -> None:
+def test_google_generativeai_get_num_tokens() -> None:
     llm = GooglePalm()
     output = llm.get_num_tokens("How are you?")
     assert output == 4


+async def test_google_generativeai_agenerate() -> None:
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    output = await llm.agenerate(["Please say foo:"])
+    assert isinstance(output, LLMResult)
+
+
+def test_generativeai_stream() -> None:
+    llm = GooglePalm(temperature=0, model_name="gemini-pro")
+    outputs = list(llm.stream("Please say foo:"))
+    assert isinstance(outputs[0], str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading a Google PaLM LLM."""
     llm = GooglePalm(max_output_tokens=10)

View File

@@ -6,11 +6,14 @@ This module integrates Google's Generative AI models, specifically the Gemini series
 The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications.

+**LLMs**
+
+The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. It allows users to generate text using a specified Gemini model.
+
 **Embeddings**

 The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models.
 These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more.

 **Installation**

 To install the package, use pip:
@@ -29,6 +32,17 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro")
 llm.invoke("Sing a ballad of LangChain.")
 ```

+## Using LLMs
+
+The package also supports generating text with Google's models.
+
+```python
+from langchain_google_genai import GoogleGenerativeAI
+
+llm = GoogleGenerativeAI(model="gemini-pro")
+llm.invoke("Once upon a time, a library called LangChain")
+```
+
 ## Embedding Generation

 The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications.
@@ -42,5 +56,10 @@ embeddings.embed_query("hello, world!")
 """  # noqa: E501

 from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+from langchain_google_genai.llms import GoogleGenerativeAI

-__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"]
+__all__ = [
+    "ChatGoogleGenerativeAI",
+    "GoogleGenerativeAIEmbeddings",
+    "GoogleGenerativeAI",
+]
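
Besides `invoke`, the new `GoogleGenerativeAI` class implements `_stream` (see `llms.py` below), so token streaming is available through the standard LangChain `.stream()` interface for Gemini models; a minimal sketch, assuming a valid `GOOGLE_API_KEY`:

```python
from langchain_google_genai import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-pro", temperature=0)
# Each item is one streamed text chunk (a GenerationChunk's text).
for chunk in llm.stream("Once upon a time, a library called LangChain"):
    print(chunk, end="", flush=True)
```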

View File

@@ -0,0 +1,262 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import google.api_core
import google.generativeai as genai  # type: ignore[import]
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env


def _create_retry_decorator(
    llm: BaseLLM,
    *,
    max_retries: int = 1,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Creates a retry decorator for Vertex / Palm LLMs."""
    errors = [
        google.api_core.exceptions.ResourceExhausted,
        google.api_core.exceptions.ServiceUnavailable,
        google.api_core.exceptions.Aborted,
        google.api_core.exceptions.DeadlineExceeded,
        google.api_core.exceptions.GoogleAPIError,
    ]
    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def _completion_with_retry(
    llm: GoogleGenerativeAI,
    prompt: LanguageModelInput,
    is_gemini: bool = False,
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
    ) -> Any:
        generation_config = kwargs.get("generation_config", {})
        if is_gemini:
            return llm.client.generate_content(
                contents=prompt, stream=stream, generation_config=generation_config
            )
        return llm.client.generate_text(prompt=prompt, **kwargs)

    return _completion_with_retry(
        prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
    )


def _is_gemini_model(model_name: str) -> bool:
    return "gemini" in model_name


def _strip_erroneous_leading_spaces(text: str) -> str:
    """Strip erroneous leading spaces from text.

    The PaLM API will sometimes erroneously return a single leading space in all
    lines > 1. This function strips that space.
    """
    has_leading_space = all(not line or line[0] == " " for line in text.split("\n")[1:])
    if has_leading_space:
        return text.replace("\n ", "\n")
    else:
        return text


class GoogleGenerativeAI(BaseLLM, BaseModel):
    """Google GenerativeAI models.

    Example:
        .. code-block:: python

            from langchain_google_genai import GoogleGenerativeAI

            llm = GoogleGenerativeAI(model="gemini-pro")
    """

    client: Any  #: :meta private:
    model: str = Field(
        ...,
        description="""The name of the model to use.
Supported examples:
    - gemini-pro
    - models/text-bison-001""",
    )
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    temperature: float = 0.7
    """Run inference with this temperature. Must be in the closed interval
    [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable tokens.
    Must be positive."""
    max_output_tokens: Optional[int] = None
    """Maximum number of tokens to include in a candidate. Must be greater than zero.
    If unset, will default to 64."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
    max_retries: int = 6
    """The maximum number of retries to make when generating."""

    @property
    def is_gemini(self) -> bool:
        """Returns whether the model belongs to the Gemini family."""
        return _is_gemini_model(self.model)

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"google_api_key": "GOOGLE_API_KEY"}

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists."""
        google_api_key = get_from_dict_or_env(
            values, "google_api_key", "GOOGLE_API_KEY"
        )
        model_name = values["model"]

        if isinstance(google_api_key, SecretStr):
            google_api_key = google_api_key.get_secret_value()

        genai.configure(api_key=google_api_key)

        if _is_gemini_model(model_name):
            values["client"] = genai.GenerativeModel(model_name=model_name)
        else:
            values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        if values["max_output_tokens"] is not None and values["max_output_tokens"] <= 0:
            raise ValueError("max_output_tokens must be greater than zero")

        return values

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        generations: List[List[Generation]] = []
        generation_config = {
            "stop_sequences": stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_output_tokens": self.max_output_tokens,
            "candidate_count": self.n,
        }
        for prompt in prompts:
            if self.is_gemini:
                res = _completion_with_retry(
                    self,
                    prompt=prompt,
                    stream=False,
                    is_gemini=True,
                    run_manager=run_manager,
                    generation_config=generation_config,
                )
                candidates = [
                    "".join([p.text for p in c.content.parts]) for c in res.candidates
                ]
                generations.append([Generation(text=c) for c in candidates])
            else:
                res = _completion_with_retry(
                    self,
                    model=self.model,
                    prompt=prompt,
                    stream=False,
                    is_gemini=False,
                    run_manager=run_manager,
                    **generation_config,
                )
                prompt_generations = []
                for candidate in res.candidates:
                    raw_text = candidate["output"]
                    stripped_text = _strip_erroneous_leading_spaces(raw_text)
                    prompt_generations.append(Generation(text=stripped_text))
                generations.append(prompt_generations)
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generation_config = kwargs.get("generation_config", {})
        if stop:
            generation_config["stop_sequences"] = stop
        for stream_resp in _completion_with_retry(
            self,
            prompt,
            stream=True,
            is_gemini=True,
            run_manager=run_manager,
            generation_config=generation_config,
            **kwargs,
        ):
            chunk = GenerationChunk(text=stream_resp.text)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    stream_resp.text,
                    chunk=chunk,
                    verbose=self.verbose,
                )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "google_palm"

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Useful for checking if an input will fit in a model's context window.

        Args:
            text: The string input to tokenize.

        Returns:
            The integer number of tokens in the text.
        """
        if self.is_gemini:
            raise ValueError("Counting tokens is not yet supported!")
        result = self.client.count_text_tokens(model=self.model, prompt=text)
        return result["token_count"]
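
To make the branching above concrete, a short usage sketch (illustrative prompts; requires a valid `GOOGLE_API_KEY`): `generate` returns one list of `Generation`s per prompt for either model family, while `get_num_tokens` is only wired to the PaLM text API and raises for Gemini models:

```python
from langchain_google_genai.llms import GoogleGenerativeAI

llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.3)
result = llm.generate(["Say foo:", "Say bar:"])
print(len(result.generations))  # 2 -- one generation list per prompt

# Token counting is only implemented for the PaLM text API here:
palm = GoogleGenerativeAI(model="models/text-bison-001")
print(palm.get_num_tokens("How are you?"))  # 4, per the integration tests

try:
    llm.get_num_tokens("hello")
except ValueError as err:
    print(err)  # Counting tokens is not yet supported!
```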

View File

@@ -0,0 +1,65 @@
"""Test Google GenerativeAI API wrapper.

Note: This test must be run with the GOOGLE_API_KEY environment variable set to a
valid API key.
"""
import pytest
from langchain_core.outputs import LLMResult

from langchain_google_genai.llms import GoogleGenerativeAI

model_names = [None, "models/text-bison-001", "gemini-pro"]


@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_call(model_name: str) -> None:
    """Test valid call to Google GenerativeAI text API."""
    if model_name:
        llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
    else:
        llm = GoogleGenerativeAI(max_output_tokens=10)
    output = llm("Say foo:")
    assert isinstance(output, str)
    assert llm._llm_type == "google_palm"
    if model_name and "gemini" in model_name:
        assert llm.client.model_name == "models/gemini-pro"
    else:
        assert llm.model == "models/text-bison-001"


@pytest.mark.parametrize(
    "model_name",
    model_names,
)
def test_google_generativeai_generate(model_name: str) -> None:
    n = 1 if model_name == "gemini-pro" else 2
    if model_name:
        llm = GoogleGenerativeAI(temperature=0.3, n=n, model=model_name)
    else:
        llm = GoogleGenerativeAI(temperature=0.3, n=n)
    output = llm.generate(["Say foo:"])
    assert isinstance(output, LLMResult)
    assert len(output.generations) == 1
    assert len(output.generations[0]) == n


def test_google_generativeai_get_num_tokens() -> None:
    llm = GoogleGenerativeAI()
    output = llm.get_num_tokens("How are you?")
    assert output == 4


async def test_google_generativeai_agenerate() -> None:
    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    output = await llm.agenerate(["Please say foo:"])
    assert isinstance(output, LLMResult)


def test_generativeai_stream() -> None:
    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    outputs = list(llm.stream("Please say foo:"))
    assert isinstance(outputs[0], str)

View File

@@ -3,6 +3,7 @@ from langchain_google_genai import __all__

 EXPECTED_ALL = [
     "ChatGoogleGenerativeAI",
     "GoogleGenerativeAIEmbeddings",
+    "GoogleGenerativeAI",
 ]