mirror of https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
added get_num_tokens to GooglePalm (#14282)
added get_num_tokens to GooglePalm + a little bit of refactoring
This commit is contained in:
parent c215a4c9ec
commit fd5be55a7b
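In short: this commit adds GooglePalm.get_num_tokens, backed by the PaLM client's count_text_tokens call, and refactors GooglePalm and VertexAI to share a single retry decorator, create_retry_decorator in langchain.utilities.vertexai. A minimal usage sketch of the new method, assuming the PaLM SDK is installed and GOOGLE_API_KEY is set in the environment:

    from langchain.llms.google_palm import GooglePalm

    llm = GooglePalm()  # default model is models/text-bison-001, per the tests below
    # Delegates to client.count_text_tokens and returns its "token_count" field.
    print(llm.get_num_tokens("How are you?"))  # the new integration test expects 4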
langchain/llms/google_palm.py
@@ -1,62 +1,32 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
-
-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)
 
-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)
 
 
 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@ class GooglePalm(BaseLLM, BaseModel):
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@ class GooglePalm(BaseLLM, BaseModel):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]
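As its docstring says, get_num_tokens is mainly useful for checking whether an input fits the model's context window before sending it. A hedged sketch of that pre-flight check (MAX_INPUT_TOKENS is an illustrative placeholder, not a value from this commit):

    llm = GooglePalm()
    prompt = "Summarize the following document: ..."
    MAX_INPUT_TOKENS = 8000  # hypothetical budget; use your model's real input limit
    if llm.get_num_tokens(prompt) > MAX_INPUT_TOKENS:
        raise ValueError("Prompt exceeds the model's context window.")
    result = llm.generate([prompt])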
langchain/llms/vertexai.py
@@ -4,13 +4,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Dict,
     Iterator,
     List,
     Optional,
-    Union,
 )
 
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -20,8 +18,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM
 from langchain.utilities.vertexai import (
+    create_retry_decorator,
     get_client_info,
     init_vertexai,
     raise_vertex_import_error,
@@ -65,27 +64,6 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name
 
 
-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
 def completion_with_retry(
     llm: VertexAI,
     *args: Any,
@@ -93,7 +71,7 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +87,9 @@ def stream_completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -125,7 +105,7 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
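All three Vertex wrappers (sync, streaming, async) now delegate to the shared decorator; note that only stream_completion_with_retry forwards llm.max_retries, so the other two fall back to the decorator's default of 1 (see its definition in the next file). A rough sketch of how a caller exercises one of them; the model name and keyword arguments are illustrative, since the inner client call is elided in the hunk above:

    from langchain.llms.vertexai import VertexAI, completion_with_retry

    llm = VertexAI(model_name="text-bison", max_retries=3)  # illustrative parameters
    # Retries on ResourceExhausted, ServiceUnavailable, Aborted, DeadlineExceeded.
    result = completion_with_retry(llm, prompt="Say foo:")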
langchain/utilities/vertexai.py
@@ -1,12 +1,43 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 
 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials
 
 
+def create_retry_decorator(
+    llm: BaseLLM,
+    *,
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Creates a retry decorator for Vertex / Palm LLMs."""
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+        google.api_core.exceptions.GoogleAPIError,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=max_retries, run_manager=run_manager
+    )
+    return decorator
+
+
 def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
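Hoisting the decorator into langchain.utilities.vertexai is what lets both LLMs (and, in principle, any BaseLLM wrapping a Google client) share one backoff policy. A sketch of reusing it for an arbitrary callable; call_with_retry is a hypothetical helper, not part of this commit:

    from langchain.utilities.vertexai import create_retry_decorator

    def call_with_retry(llm, fn, *args, **kwargs):
        # llm is any BaseLLM; max_retries falls back to 1 when absent.
        decorator = create_retry_decorator(
            llm, max_retries=getattr(llm, "max_retries", 1)
        )
        return decorator(fn)(*args, **kwargs)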
tests/integration_tests/llms/test_google_palm.py
@@ -6,6 +6,8 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 
 from pathlib import Path
 
+from langchain_core.outputs import LLMResult
+
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.loading import load_llm
 
@@ -15,6 +17,22 @@ def test_google_palm_call() -> None:
     llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert llm._llm_type == "google_palm"
+    assert llm.model_name == "models/text-bison-001"
+
+
+def test_google_palm_generate() -> None:
+    llm = GooglePalm(temperature=0.3, n=2)
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
+
+
+def test_google_palm_get_num_tokens() -> None:
+    llm = GooglePalm()
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
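These are live-API tests: per the note in the file header they require GOOGLE_API_KEY, and the assertion output == 4 pins the live PaLM tokenizer's count for "How are you?". One way to run only the new test, as a sketch (the path follows the layout implied by the imports):

    import pytest

    pytest.main(
        ["tests/integration_tests/llms/test_google_palm.py", "-k", "get_num_tokens"]
    )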