Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-10 13:27:36 +00:00)
google-genai: added logic for method get_num_tokens() (#16205)
- Description: added logic for method get_num_tokens() for ChatGoogleGenerativeAI and GoogleGenerativeAI
- Issue: https://github.com/langchain-ai/langchain/issues/16204
- Dependencies: none
- Twitter handle: @Aditya_Rane

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
parent 0785432e7b
commit 9dd7cbb447
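A quick sketch of what the new method enables (a hedged illustration; it mirrors the integration tests added below, which assert a 4-token count for this prompt on gemini-pro):

    from langchain_google_genai import ChatGoogleGenerativeAI

    chat = ChatGoogleGenerativeAI(model="gemini-pro")
    num = chat.get_num_tokens("How are you?")  # the new integration tests expect 4
    print(num)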
libs/partners/google-genai/langchain_google_genai/chat_models.py

@@ -42,7 +42,7 @@ from langchain_core.messages import (
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -53,6 +53,7 @@ from tenacity import (
 )

 from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

 IMAGE_TYPES: Tuple = ()
 try:
@@ -417,7 +418,7 @@ def _response_to_result(
     return ChatResult(generations=generations, llm_output=llm_output)


-class ChatGoogleGenerativeAI(BaseChatModel):
+class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """`Google Generative AI` Chat models API.

     To use, you must have either:
@@ -435,53 +436,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):
     """

-    model: str = Field(
-        ...,
-        description="""The name of the model to use.
-Supported examples:
-    - gemini-pro""",
-    )
-    max_output_tokens: int = Field(default=None, description="Max output tokens")
-
     client: Any  #: :meta private:
-    google_api_key: Optional[SecretStr] = None
-    temperature: Optional[float] = None
-    """Run inference with this temperature. Must by in the closed
-    interval [0.0, 1.0]."""
-    top_k: Optional[int] = None
-    """Decode using top-k sampling: consider the set of top_k most probable tokens.
-    Must be positive."""
-    top_p: Optional[float] = None
-    """The maximum cumulative probability of tokens to consider when sampling.
-
-    The model uses combined Top-k and nucleus sampling.
-
-    Tokens are sorted based on their assigned probabilities so
-    that only the most likely tokens are considered. Top-k
-    sampling directly limits the maximum number of tokens to
-    consider, while Nucleus sampling limits number of tokens
-    based on the cumulative probability.
-
-    Note: The default value varies by model, see the
-    `Model.top_p` attribute of the `Model` returned the
-    `genai.get_model` function.
-    """
-    n: int = Field(default=1, alias="candidate_count")
-    """Number of chat completions to generate for each prompt. Note that the API may
-    not return the full n completions if duplicates are generated."""
-    convert_system_message_to_human: bool = False
-    """Whether to merge any leading SystemMessage into the following HumanMessage.
-
-    Gemini does not support system messages; any unsupported messages will
-    raise an error."""
-    client_options: Optional[Dict] = Field(
-        None,
-        description="Client options to pass to the Google API client.",
-    )
-    transport: Optional[str] = Field(
-        None,
-        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-    )

     class Config:
         allow_population_by_field_name = True
@@ -494,10 +455,6 @@ Supported examples:
     def _llm_type(self) -> str:
         return "chat-google-generative-ai"

-    @property
-    def _is_geminiai(self) -> bool:
-        return self.model is not None and "gemini" in self.model
-
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
@@ -658,3 +615,23 @@ Supported examples:
         message = history.pop()
         chat = self.client.start_chat(history=history)
         return params, chat, message
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
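For reference, a minimal standalone sketch of the two counting paths this method dispatches between, calling the google-generativeai SDK directly. The PaLM model name below is an assumption for illustration; the hunk itself only passes `self.model` through:

    import google.generativeai as genai

    genai.configure(api_key="YOUR_API_KEY")  # placeholder key

    # Gemini family: the client is a genai.GenerativeModel, and
    # count_tokens() returns an object with a total_tokens attribute.
    gemini = genai.GenerativeModel(model_name="gemini-pro")
    print(gemini.count_tokens("How are you?").total_tokens)

    # PaLM family: the client is the genai module itself, and
    # count_text_tokens() returns a dict keyed by "token_count".
    palm_model = "models/text-bison-001"  # assumed model name, for illustration only
    print(genai.count_text_tokens(model=palm_model, prompt="How are you?")["token_count"])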
libs/partners/google-genai/langchain_google_genai/llms.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

+from enum import Enum, auto
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

 import google.api_core
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env


+class GoogleModelFamily(str, Enum):
+    GEMINI = auto()
+    PALM = auto()
+
+    @classmethod
+    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+        if "gemini" in value.lower():
+            return GoogleModelFamily.GEMINI
+        elif "text-bison" in value.lower():
+            return GoogleModelFamily.PALM
+        return None
+
+
 def _create_retry_decorator(
     llm: BaseLLM,
     *,
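The `_missing_` hook is what lets `GoogleModelFamily("gemini-pro")` resolve even though the enum values are auto-generated: when plain value lookup fails, `Enum` falls back to `_missing_`, which substring-matches the model name. A self-contained sketch of that behavior, copying the class from this hunk:

    from enum import Enum, auto
    from typing import Any, Optional

    class GoogleModelFamily(str, Enum):
        GEMINI = auto()
        PALM = auto()

        @classmethod
        def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
            if "gemini" in value.lower():
                return GoogleModelFamily.GEMINI
            elif "text-bison" in value.lower():
                return GoogleModelFamily.PALM
            return None

    assert GoogleModelFamily("gemini-ultra") is GoogleModelFamily.GEMINI
    assert GoogleModelFamily("text-bison-001") is GoogleModelFamily.PALM
    # Any other name makes _missing_ return None, so Enum raises ValueError.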
@@ -75,10 +89,6 @@ def _completion_with_retry(
     )


-def _is_gemini_model(model_name: str) -> bool:
-    return "gemini" in model_name
-
-
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.

@@ -92,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
     return text


-class GoogleGenerativeAI(BaseLLM, BaseModel):
-    """Google GenerativeAI models.
+class _BaseGoogleGenerativeAI(BaseModel):
+    """Base class for Google Generative AI LLMs"""

-    Example:
-        .. code-block:: python
-
-            from langchain_google_genai import GoogleGenerativeAI
-            llm = GoogleGenerativeAI(model="gemini-pro")
-    """
-
-    client: Any  #: :meta private:
     model: str = Field(
         ...,
         description="""The name of the model to use.
@@ -141,15 +143,27 @@ Supported examples:
         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
     )

-    @property
-    def is_gemini(self) -> bool:
-        """Returns whether a model is belongs to a Gemini family or not."""
-        return _is_gemini_model(self.model)
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}

+    @property
+    def _model_family(self) -> str:
+        return GoogleModelFamily(self.model)
+
+
+class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+    """Google GenerativeAI models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import GoogleGenerativeAI
+            llm = GoogleGenerativeAI(model="gemini-pro")
+    """
+
+    client: Any  #: :meta private:
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
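The refactor above extracts the shared fields into `_BaseGoogleGenerativeAI`, which both the LLM and the chat model now inherit. A minimal sketch of the resulting shape (abridged; the real classes carry many more fields and methods, as the hunks above show, and this assumes the langchain_core version this commit targets):

    from langchain_core.language_models import BaseChatModel, BaseLLM
    from langchain_core.pydantic_v1 import BaseModel

    class _BaseGoogleGenerativeAI(BaseModel):
        model: str  # shared field; _model_family also lives here

    class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):  # completion LLM
        ...

    class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):  # chat model
        ...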
@@ -167,7 +181,7 @@ Supported examples:
             client_options=values.get("client_options"),
         )

-        if _is_gemini_model(model_name):
+        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
             values["client"] = genai.GenerativeModel(model_name=model_name)
         else:
             values["client"] = genai
@@ -203,7 +217,7 @@ Supported examples:
             "candidate_count": self.n,
         }
         for prompt in prompts:
-            if self.is_gemini:
+            if self._model_family == GoogleModelFamily.GEMINI:
                 res = _completion_with_retry(
                     self,
                     prompt=prompt,
@@ -279,7 +293,11 @@ Supported examples:
         Returns:
             The integer number of tokens in the text.
         """
-        if self.is_gemini:
-            raise ValueError("Counting tokens is not yet supported!")
-        result = self.client.count_text_tokens(model=self.model, prompt=text)
-        return result["token_count"]
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
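Worth noting the behavior change in this hunk: previously `get_num_tokens` raised for Gemini models and only PaLM was counted; now both families are handled. A hedged illustration:

    from langchain_google_genai import GoogleGenerativeAI

    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
    # Before this change, the Gemini branch raised
    # ValueError("Counting tokens is not yet supported!").
    # After it, the count comes from GenerativeModel.count_tokens().
    print(llm.get_num_tokens("How are you?"))  # the new tests assert 4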
libs/partners/google-genai/poetry.lock (generated, 5 lines changed)
@@ -280,12 +280,12 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]

 [[package]]
 name = "google-generativeai"
-version = "0.3.1"
+version = "0.3.2"
 description = "Google Generative AI High level API client library and tools."
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
+    {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
 ]

 [package.dependencies]
@@ -294,6 +294,7 @@ google-api-core = "*"
 google-auth = "*"
 protobuf = "*"
+tqdm = "*"
 typing-extensions = "*"

 [package.extras]
 dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
libs/partners/google-genai/tests/integration_tests/test_chat_models.py

@@ -186,3 +186,9 @@ def test_chat_google_genai_system_message() -> None:
     response = model([system_message, message1, message2, message3])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
libs/partners/google-genai/tests/integration_tests/test_llms.py

@@ -60,3 +60,9 @@ def test_generativeai_stream() -> None:
     llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
libs/partners/google-genai/tests/unit_tests/test_llms.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+from langchain_google_genai.llms import GoogleModelFamily
+
+
+def test_model_family() -> None:
+    model = GoogleModelFamily("gemini-pro")
+    assert model == GoogleModelFamily.GEMINI
+    model = GoogleModelFamily("gemini-ultra")
+    assert model == GoogleModelFamily.GEMINI