diff --git a/libs/langchain/langchain/llms/google_palm.py b/libs/langchain/langchain/llms/google_palm.py
index 1e40143c897..16491df1f90 100644
--- a/libs/langchain/langchain/llms/google_palm.py
+++ b/libs/langchain/langchain/llms/google_palm.py
@@ -1,62 +1,32 @@
 from __future__ import annotations

-import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env

-logger = logging.getLogger(__name__)
-

-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
     )

-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
-
     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)

-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)


 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""

     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@ class GooglePalm(BaseLLM, BaseModel):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]
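For reference, a minimal sketch of how the refactored helper is exercised, mirroring the call site in `_generate` above. It assumes `GOOGLE_API_KEY` is set and the `google-generativeai` package is installed; the `max_retries=3` value is purely illustrative.

```python
from langchain.llms.google_palm import GooglePalm, completion_with_retry

llm = GooglePalm(max_output_tokens=10, max_retries=3)

# Transient google.api_core errors (ResourceExhausted, ServiceUnavailable,
# Aborted, DeadlineExceeded, GoogleAPIError) are retried with exponential
# backoff by the shared decorator before being re-raised to the caller.
completion = completion_with_retry(
    llm,
    model=llm.model_name,
    prompt="Say foo:",
)
```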
diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index f074f6578d0..0dedeaf0df8 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -4,13 +4,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Dict,
     Iterator,
     List,
     Optional,
-    Union,
 )

 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -20,8 +18,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM
 from langchain.utilities.vertexai import (
+    create_retry_decorator,
     get_client_info,
     init_vertexai,
     raise_vertex_import_error,
@@ -65,27 +64,6 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name


-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
 def completion_with_retry(
     llm: VertexAI,
     *args: Any,
@@ -93,7 +71,7 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +87,9 @@ def stream_completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )

     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -125,7 +105,7 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
     async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
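One detail worth flagging in the hunks above: only `stream_completion_with_retry` forwards `llm.max_retries` to the shared helper, while `completion_with_retry` and `acompletion_with_retry` fall back to `create_retry_decorator`'s default of `max_retries=1`. A sketch contrasting the resulting decorator construction (assumes the Vertex AI SDK is installed and credentials are already configured):

```python
from langchain.llms.vertexai import VertexAI
from langchain.utilities.vertexai import create_retry_decorator

llm = VertexAI()  # max_retries defaults to 6 on the model

# completion_with_retry / acompletion_with_retry: helper default applies.
sync_retry = create_retry_decorator(llm, run_manager=None)  # max_retries=1

# stream_completion_with_retry: the configured retry budget is forwarded.
stream_retry = create_retry_decorator(
    llm, max_retries=llm.max_retries, run_manager=None
)
```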
diff --git a/libs/langchain/langchain/utilities/vertexai.py b/libs/langchain/langchain/utilities/vertexai.py
index a60d5eea39a..b0c96447903 100644
--- a/libs/langchain/langchain/utilities/vertexai.py
+++ b/libs/langchain/langchain/utilities/vertexai.py
@@ -1,12 +1,43 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)

 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials


+def create_retry_decorator(
+    llm: BaseLLM,
+    *,
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Creates a retry decorator for Vertex / Palm LLMs."""
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+        google.api_core.exceptions.GoogleAPIError,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=max_retries, run_manager=run_manager
+    )
+    return decorator
+
+
 def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
diff --git a/libs/langchain/tests/integration_tests/llms/test_google_palm.py b/libs/langchain/tests/integration_tests/llms/test_google_palm.py
index ca02b185f0d..b82895ba258 100644
--- a/libs/langchain/tests/integration_tests/llms/test_google_palm.py
+++ b/libs/langchain/tests/integration_tests/llms/test_google_palm.py
@@ -6,6 +6,8 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to

 from pathlib import Path

+from langchain_core.outputs import LLMResult
+
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.loading import load_llm

@@ -15,6 +17,22 @@ def test_google_palm_call() -> None:
     llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert llm._llm_type == "google_palm"
+    assert llm.model_name == "models/text-bison-001"
+
+
+def test_google_palm_generate() -> None:
+    llm = GooglePalm(temperature=0.3, n=2)
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
+
+
+def test_google_palm_get_num_tokens() -> None:
+    llm = GooglePalm()
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4


 def test_saving_loading_llm(tmp_path: Path) -> None:
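Finally, a quick interactive equivalent of the new integration tests, assuming `GOOGLE_API_KEY` is set in the environment (the expected token count of 4 is the value the test above asserts):

```python
from langchain.llms.google_palm import GooglePalm

llm = GooglePalm(temperature=0.3, n=2)

# n=2 yields two candidate completions for the single prompt.
result = llm.generate(["Say foo:"])
print(len(result.generations[0]))  # 2

# Token counting via the new get_num_tokens method.
print(GooglePalm().get_num_tokens("How are you?"))  # 4
```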