google-genai: added logic for method get_num_tokens() (#16205)

<!-- Thank you for contributing to LangChain!

Please title your PR "partners: google-genai",

Replace this entire comment with:
- **Description:** : added logic for method get_num_tokens() for
ChatGoogleGenerativeAI , GoogleGenerativeAI,
  - **Issue:** : https://github.com/langchain-ai/langchain/issues/16204,
  - **Dependencies:** : None,
  - **Twitter handle:** @Aditya_Rane

---------

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
This commit is contained in:
Aditya 2024-01-25 10:13:16 +05:30 committed by GitHub
parent 0785432e7b
commit 9dd7cbb447
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 89 additions and 73 deletions

View File

@ -42,7 +42,7 @@ from langchain_core.messages import (
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env
from tenacity import (
before_sleep_log,
@ -53,6 +53,7 @@ from tenacity import (
)
from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
IMAGE_TYPES: Tuple = ()
try:
@ -417,7 +418,7 @@ def _response_to_result(
return ChatResult(generations=generations, llm_output=llm_output)
class ChatGoogleGenerativeAI(BaseChatModel):
class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
"""`Google Generative AI` Chat models API.
To use, you must have either:
@ -435,53 +436,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):
"""
model: str = Field(
...,
description="""The name of the model to use.
Supported examples:
- gemini-pro""",
)
max_output_tokens: int = Field(default=None, description="Max output tokens")
client: Any #: :meta private:
google_api_key: Optional[SecretStr] = None
temperature: Optional[float] = None
"""Run inference with this temperature. Must by in the closed
interval [0.0, 1.0]."""
top_k: Optional[int] = None
"""Decode using top-k sampling: consider the set of top_k most probable tokens.
Must be positive."""
top_p: Optional[float] = None
"""The maximum cumulative probability of tokens to consider when sampling.
The model uses combined Top-k and nucleus sampling.
Tokens are sorted based on their assigned probabilities so
that only the most likely tokens are considered. Top-k
sampling directly limits the maximum number of tokens to
consider, while Nucleus sampling limits number of tokens
based on the cumulative probability.
Note: The default value varies by model, see the
`Model.top_p` attribute of the `Model` returned the
`genai.get_model` function.
"""
n: int = Field(default=1, alias="candidate_count")
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""
client_options: Optional[Dict] = Field(
None,
description="Client options to pass to the Google API client.",
)
transport: Optional[str] = Field(
None,
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
class Config:
allow_population_by_field_name = True
@ -494,10 +455,6 @@ Supported examples:
def _llm_type(self) -> str:
return "chat-google-generative-ai"
@property
def _is_geminiai(self) -> bool:
return self.model is not None and "gemini" in self.model
@classmethod
def is_lc_serializable(self) -> bool:
return True
@ -658,3 +615,23 @@ Supported examples:
message = history.pop()
chat = self.client.start_chat(history=history)
return params, chat, message
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Args:
text: The string input to tokenize.
Returns:
The integer number of tokens in the text.
"""
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]
return token_count

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from enum import Enum, auto
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import google.api_core
@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validat
from langchain_core.utils import get_from_dict_or_env
class GoogleModelFamily(str, Enum):
GEMINI = auto()
PALM = auto()
@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
if "gemini" in value.lower():
return GoogleModelFamily.GEMINI
elif "text-bison" in value.lower():
return GoogleModelFamily.PALM
return None
def _create_retry_decorator(
llm: BaseLLM,
*,
@ -75,10 +89,6 @@ def _completion_with_retry(
)
def _is_gemini_model(model_name: str) -> bool:
return "gemini" in model_name
def _strip_erroneous_leading_spaces(text: str) -> str:
"""Strip erroneous leading spaces from text.
@ -92,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
return text
class GoogleGenerativeAI(BaseLLM, BaseModel):
"""Google GenerativeAI models.
class _BaseGoogleGenerativeAI(BaseModel):
"""Base class for Google Generative AI LLMs"""
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""
client: Any #: :meta private:
model: str = Field(
...,
description="""The name of the model to use.
@ -141,15 +143,27 @@ Supported examples:
description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
)
@property
def is_gemini(self) -> bool:
"""Returns whether a model is belongs to a Gemini family or not."""
return _is_gemini_model(self.model)
@property
def lc_secrets(self) -> Dict[str, str]:
return {"google_api_key": "GOOGLE_API_KEY"}
@property
def _model_family(self) -> str:
return GoogleModelFamily(self.model)
class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
"""Google GenerativeAI models.
Example:
.. code-block:: python
from langchain_google_genai import GoogleGenerativeAI
llm = GoogleGenerativeAI(model="gemini-pro")
"""
client: Any #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates params and passes them to google-generativeai package."""
@ -167,7 +181,7 @@ Supported examples:
client_options=values.get("client_options"),
)
if _is_gemini_model(model_name):
if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
values["client"] = genai.GenerativeModel(model_name=model_name)
else:
values["client"] = genai
@ -203,7 +217,7 @@ Supported examples:
"candidate_count": self.n,
}
for prompt in prompts:
if self.is_gemini:
if self._model_family == GoogleModelFamily.GEMINI:
res = _completion_with_retry(
self,
prompt=prompt,
@ -279,7 +293,11 @@ Supported examples:
Returns:
The integer number of tokens in the text.
"""
if self.is_gemini:
raise ValueError("Counting tokens is not yet supported!")
result = self.client.count_text_tokens(model=self.model, prompt=text)
return result["token_count"]
if self._model_family == GoogleModelFamily.GEMINI:
result = self.client.count_tokens(text)
token_count = result.total_tokens
else:
result = self.client.count_text_tokens(model=self.model, prompt=text)
token_count = result["token_count"]
return token_count

View File

@ -280,12 +280,12 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "google-generativeai"
version = "0.3.1"
version = "0.3.2"
description = "Google Generative AI High level API client library and tools."
optional = false
python-versions = ">=3.9"
files = [
{file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
{file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
]
[package.dependencies]
@ -294,6 +294,7 @@ google-api-core = "*"
google-auth = "*"
protobuf = "*"
tqdm = "*"
typing-extensions = "*"
[package.extras]
dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]

View File

@ -186,3 +186,9 @@ def test_chat_google_genai_system_message() -> None:
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_generativeai_get_num_tokens_gemini() -> None:
llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4

View File

@ -60,3 +60,9 @@ def test_generativeai_stream() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
def test_generativeai_get_num_tokens_gemini() -> None:
llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
output = llm.get_num_tokens("How are you?")
assert output == 4

View File

@ -0,0 +1,8 @@
from langchain_google_genai.llms import GoogleModelFamily
def test_model_family() -> None:
model = GoogleModelFamily("gemini-pro")
assert model == GoogleModelFamily.GEMINI
model = GoogleModelFamily("gemini-ultra")
assert model == GoogleModelFamily.GEMINI