google-genai: added logic for method get_num_tokens() (#16205)
**Description:** added logic for method get_num_tokens() for ChatGoogleGenerativeAI and GoogleGenerativeAI
**Issue:** https://github.com/langchain-ai/langchain/issues/16204
**Dependencies:** None
**Twitter handle:** @Aditya_Rane

Co-authored-by: adityarane@google.com <adityarane@google.com>
Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
commit 9dd7cbb447 (parent 0785432e7b)
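
A minimal usage sketch of the new method (not part of the commit; assumes GOOGLE_API_KEY is set, and the input-token limit shown is illustrative):

    # Check prompt size via the API's token counter before invoking.
    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    prompt = "How are you?"
    num_tokens = llm.get_num_tokens(prompt)  # 4, per the integration tests below
    if num_tokens <= 30720:  # assumed gemini-pro input window; check your model's limit
        response = llm.invoke(prompt)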
libs/partners/google-genai/langchain_google_genai/chat_models.py

@@ -42,7 +42,7 @@ from langchain_core.messages import (
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.pydantic_v1 import SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 from tenacity import (
     before_sleep_log,
@@ -53,6 +53,7 @@ from tenacity import (
 )

 from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI

 IMAGE_TYPES: Tuple = ()
 try:
@@ -417,7 +418,7 @@ def _response_to_result(
     return ChatResult(generations=generations, llm_output=llm_output)


-class ChatGoogleGenerativeAI(BaseChatModel):
+class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
     """`Google Generative AI` Chat models API.

     To use, you must have either:
@@ -435,53 +436,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):

     """

-    model: str = Field(
-        ...,
-        description="""The name of the model to use.
-Supported examples:
-    - gemini-pro""",
-    )
-    max_output_tokens: int = Field(default=None, description="Max output tokens")
-
     client: Any  #: :meta private:
-    google_api_key: Optional[SecretStr] = None
-    temperature: Optional[float] = None
-    """Run inference with this temperature. Must by in the closed
-       interval [0.0, 1.0]."""
-    top_k: Optional[int] = None
-    """Decode using top-k sampling: consider the set of top_k most probable tokens.
-       Must be positive."""
-    top_p: Optional[float] = None
-    """The maximum cumulative probability of tokens to consider when sampling.
-
-       The model uses combined Top-k and nucleus sampling.
-
-       Tokens are sorted based on their assigned probabilities so
-       that only the most likely tokens are considered. Top-k
-       sampling directly limits the maximum number of tokens to
-       consider, while Nucleus sampling limits number of tokens
-       based on the cumulative probability.
-
-       Note: The default value varies by model, see the
-       `Model.top_p` attribute of the `Model` returned the
-       `genai.get_model` function.
-       """
-    n: int = Field(default=1, alias="candidate_count")
-    """Number of chat completions to generate for each prompt. Note that the API may
-       not return the full n completions if duplicates are generated."""
     convert_system_message_to_human: bool = False
     """Whether to merge any leading SystemMessage into the following HumanMessage.

     Gemini does not support system messages; any unsupported messages will
     raise an error."""
-    client_options: Optional[Dict] = Field(
-        None,
-        description="Client options to pass to the Google API client.",
-    )
-    transport: Optional[str] = Field(
-        None,
-        description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-    )

     class Config:
         allow_population_by_field_name = True
@@ -494,10 +455,6 @@ Supported examples:
     def _llm_type(self) -> str:
         return "chat-google-generative-ai"

-    @property
-    def _is_geminiai(self) -> bool:
-        return self.model is not None and "gemini" in self.model
-
     @classmethod
     def is_lc_serializable(self) -> bool:
         return True
@@ -658,3 +615,23 @@ Supported examples:
         message = history.pop()
         chat = self.client.start_chat(history=history)
         return params, chat, message
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
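The two client calls above return different shapes. A rough sketch against the google-generativeai 0.3.x SDK pinned in the lockfile below (model names illustrative):

    import google.generativeai as genai

    # Gemini family: the client is a GenerativeModel instance;
    # count_tokens returns an object with a total_tokens attribute.
    gemini = genai.GenerativeModel(model_name="gemini-pro")
    n = gemini.count_tokens("How are you?").total_tokens

    # PaLM family: the client is the genai module itself;
    # count_text_tokens returns a dict keyed by "token_count".
    n = genai.count_text_tokens(model="models/text-bison-001", prompt="How are you?")["token_count"]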
libs/partners/google-genai/langchain_google_genai/llms.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

+from enum import Enum, auto
 from typing import Any, Callable, Dict, Iterator, List, Optional, Union

 import google.api_core
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validat
 from langchain_core.utils import get_from_dict_or_env


+class GoogleModelFamily(str, Enum):
+    GEMINI = auto()
+    PALM = auto()
+
+    @classmethod
+    def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+        if "gemini" in value.lower():
+            return GoogleModelFamily.GEMINI
+        elif "text-bison" in value.lower():
+            return GoogleModelFamily.PALM
+        return None
+
+
 def _create_retry_decorator(
     llm: BaseLLM,
     *,
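GoogleModelFamily leans on Python's Enum._missing_ hook: when lookup by value fails, Enum calls _missing_, so any model-name string resolves to a family. A self-contained sketch of the same mechanism (Family is an illustrative stand-in):

    from enum import Enum, auto
    from typing import Any, Optional

    class Family(str, Enum):
        GEMINI = auto()
        PALM = auto()

        @classmethod
        def _missing_(cls, value: Any) -> Optional["Family"]:
            # Invoked by Enum when no member has this exact value.
            if "gemini" in value.lower():
                return cls.GEMINI
            elif "text-bison" in value.lower():
                return cls.PALM
            return None

    assert Family("gemini-ultra") is Family.GEMINI
    assert Family("text-bison-001") is Family.PALM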
@@ -75,10 +89,6 @@ def _completion_with_retry(
     )


-def _is_gemini_model(model_name: str) -> bool:
-    return "gemini" in model_name
-
-
 def _strip_erroneous_leading_spaces(text: str) -> str:
     """Strip erroneous leading spaces from text.

@@ -92,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
     return text


-class GoogleGenerativeAI(BaseLLM, BaseModel):
-    """Google GenerativeAI models.
-
-    Example:
-        .. code-block:: python
-
-            from langchain_google_genai import GoogleGenerativeAI
-            llm = GoogleGenerativeAI(model="gemini-pro")
-    """
-
-    client: Any  #: :meta private:
+class _BaseGoogleGenerativeAI(BaseModel):
+    """Base class for Google Generative AI LLMs"""
+
     model: str = Field(
         ...,
         description="""The name of the model to use.
@@ -141,15 +143,27 @@ Supported examples:
         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
     )

-    @property
-    def is_gemini(self) -> bool:
-        """Returns whether a model is belongs to a Gemini family or not."""
-        return _is_gemini_model(self.model)
-
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}

+    @property
+    def _model_family(self) -> str:
+        return GoogleModelFamily(self.model)
+
+
+class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+    """Google GenerativeAI models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_google_genai import GoogleGenerativeAI
+            llm = GoogleGenerativeAI(model="gemini-pro")
+    """
+
+    client: Any  #: :meta private:
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates params and passes them to google-generativeai package."""
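With the refactor, both public wrappers inherit the shared fields and the _model_family property from _BaseGoogleGenerativeAI, so family detection behaves identically for the LLM and the chat model. A minimal sketch, assuming GOOGLE_API_KEY is set:

    from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
    from langchain_google_genai.llms import GoogleModelFamily

    llm = GoogleGenerativeAI(model="gemini-pro")
    chat = ChatGoogleGenerativeAI(model="gemini-pro")
    # Both resolve through the inherited (private) _model_family property.
    assert llm._model_family == GoogleModelFamily.GEMINI
    assert chat._model_family == GoogleModelFamily.GEMINI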
@@ -167,7 +181,7 @@ Supported examples:
             client_options=values.get("client_options"),
         )

-        if _is_gemini_model(model_name):
+        if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
            values["client"] = genai.GenerativeModel(model_name=model_name)
         else:
            values["client"] = genai
@@ -203,7 +217,7 @@ Supported examples:
             "candidate_count": self.n,
         }
         for prompt in prompts:
-            if self.is_gemini:
+            if self._model_family == GoogleModelFamily.GEMINI:
                 res = _completion_with_retry(
                     self,
                     prompt=prompt,
@@ -279,7 +293,11 @@ Supported examples:
         Returns:
             The integer number of tokens in the text.
         """
-        if self.is_gemini:
-            raise ValueError("Counting tokens is not yet supported!")
-        result = self.client.count_text_tokens(model=self.model, prompt=text)
-        return result["token_count"]
+        if self._model_family == GoogleModelFamily.GEMINI:
+            result = self.client.count_tokens(text)
+            token_count = result.total_tokens
+        else:
+            result = self.client.count_text_tokens(model=self.model, prompt=text)
+            token_count = result["token_count"]
+
+        return token_count
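With this change the Gemini-path ValueError is gone and both families are supported by the LLM wrapper too. A quick usage sketch (assumes GOOGLE_API_KEY is set):

    from langchain_google_genai import GoogleGenerativeAI

    llm = GoogleGenerativeAI(model="gemini-pro")
    assert llm.get_num_tokens("How are you?") == 4  # matches the integration test below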
libs/partners/google-genai/poetry.lock (generated; 5 lines changed)
@@ -280,12 +280,12 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"]

 [[package]]
 name = "google-generativeai"
-version = "0.3.1"
+version = "0.3.2"
 description = "Google Generative AI High level API client library and tools."
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "google_generativeai-0.3.1-py3-none-any.whl", hash = "sha256:800ec6041ca537b897d7ba654f4125651c64b38506f2bfce3b464370e3333a1b"},
+    {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
 ]

 [package.dependencies]
@@ -294,6 +294,7 @@ google-api-core = "*"
 google-auth = "*"
 protobuf = "*"
 tqdm = "*"
+typing-extensions = "*"

 [package.extras]
 dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"]
libs/partners/google-genai/tests/integration_tests/test_chat_models.py

@@ -186,3 +186,9 @@ def test_chat_google_genai_system_message() -> None:
     response = model([system_message, message1, message2, message3])
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
libs/partners/google-genai/tests/integration_tests/test_llms.py

@@ -60,3 +60,9 @@ def test_generativeai_stream() -> None:
     llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
+
+
+def test_generativeai_get_num_tokens_gemini() -> None:
+    llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
libs/partners/google-genai/tests/unit_tests/test_llms.py (new file; 8 lines)
@@ -0,0 +1,8 @@
+from langchain_google_genai.llms import GoogleModelFamily
+
+
+def test_model_family() -> None:
+    model = GoogleModelFamily("gemini-pro")
+    assert model == GoogleModelFamily.GEMINI
+    model = GoogleModelFamily("gemini-ultra")
+    assert model == GoogleModelFamily.GEMINI
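A hypothetical companion case (not part of this commit) would exercise the PaLM branch of _missing_:

    def test_model_family_palm() -> None:
        # Hypothetical: PaLM-style names resolve via the same hook.
        model = GoogleModelFamily("text-bison-001")
        assert model == GoogleModelFamily.PALM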