added get_num_tokens to GooglePalm (#14282)

added get_num_tokens to GooglePalm + a little bit of refactoring
Leonid Kuligin 2023-12-06 02:24:19 +01:00 committed by GitHub
parent c215a4c9ec
commit fd5be55a7b
4 changed files with 88 additions and 73 deletions

View File

@@ -1,62 +1,32 @@
 from __future__ import annotations

-import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env

-logger = logging.getLogger(__name__)
-
-
-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
-
-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )

     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)

-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)


 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@ class GooglePalm(BaseLLM, BaseModel):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]
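
For context, a minimal usage sketch of the new method (illustrative only; assumes the PaLM client dependency is installed and GOOGLE_API_KEY is set, as the integration tests below require):

from langchain.llms.google_palm import GooglePalm

llm = GooglePalm()
# get_num_tokens delegates to the PaLM client's count_text_tokens call,
# so the count comes from the service's own tokenizer rather than a
# local approximation.
num_tokens = llm.get_num_tokens("How are you?")
print(num_tokens)  # the new integration test expects 4 for this input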

View File

@@ -4,13 +4,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Dict,
     Iterator,
     List,
     Optional,
-    Union,
 )

 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -20,8 +18,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM
 from langchain.utilities.vertexai import (
+    create_retry_decorator,
     get_client_info,
     init_vertexai,
     raise_vertex_import_error,
@@ -65,27 +64,6 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name


-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
 def completion_with_retry(
     llm: VertexAI,
     *args: Any,
@@ -93,7 +71,7 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +87,9 @@ def stream_completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -125,7 +105,7 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
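
Taken together, these hunks route the completion, streaming, and async call paths through the shared create_retry_decorator helper (defined in the next file). As a rough sketch of what that decorator does, here is a simplified stand-in mirroring the tenacity configuration deleted from the PaLM module above; the real helper is built by create_base_retry_decorator and can also report retries to a run manager:

import logging

import google.api_core.exceptions
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)

# Exponential backoff on transient Google API errors, then re-raise.
simplified_retry = retry(
    reraise=True,
    stop=stop_after_attempt(6),  # matches GooglePalm's new max_retries default
    wait=wait_exponential(multiplier=2, min=1, max=60),
    retry=retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
    | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)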

View File

@@ -1,12 +1,43 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)

 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials


+def create_retry_decorator(
+    llm: BaseLLM,
+    *,
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Creates a retry decorator for Vertex / Palm LLMs."""
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+        google.api_core.exceptions.GoogleAPIError,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=max_retries, run_manager=run_manager
+    )
+    return decorator
+
+
 def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.

View File

@@ -6,6 +6,8 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 from pathlib import Path

+from langchain_core.outputs import LLMResult
+
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.loading import load_llm
@@ -15,6 +17,22 @@ def test_google_palm_call() -> None:
     llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert llm._llm_type == "google_palm"
+    assert llm.model_name == "models/text-bison-001"
+
+
+def test_google_palm_generate() -> None:
+    llm = GooglePalm(temperature=0.3, n=2)
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
+
+
+def test_google_palm_get_num_tokens() -> None:
+    llm = GooglePalm()
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4


 def test_saving_loading_llm(tmp_path: Path) -> None: