Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-08 22:15:08 +00:00)

added get_num_tokens to GooglePalm (#14282)

Added get_num_tokens to GooglePalm, plus a little refactoring: the retry logic previously duplicated between the PaLM and Vertex AI LLMs is now a shared create_retry_decorator helper in langchain.utilities.vertexai.

Commit fd5be55a7b (parent c215a4c9ec)
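A minimal usage sketch of what the change enables, assuming google-generativeai is installed and GOOGLE_API_KEY is set; the prompts and the expected token count are taken from the integration tests in the diff below.

from langchain.llms.google_palm import GooglePalm

llm = GooglePalm(max_retries=6)  # max_retries is the new field added by this commit

# New in this commit: count tokens before sending a prompt, to check whether the
# input fits the model's context window.
print(llm.get_num_tokens("How are you?"))  # -> 4, per the integration test below

# Generation now goes through completion_with_retry, which delegates retrying to
# the shared create_retry_decorator helper.
print(llm("Say foo:"))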
@@ -1,62 +1,32 @@
 from __future__ import annotations
 
-import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.pydantic_v1 import BaseModel, root_validator
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import BaseLLM
+from langchain.utilities.vertexai import create_retry_decorator
 from langchain.utils import get_from_dict_or_env
 
-logger = logging.getLogger(__name__)
-
 
-def _create_retry_decorator() -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle PaLM exceptions"""
-    try:
-        import google.api_core.exceptions
-    except ImportError:
-        raise ImportError(
-            "Could not import google-api-core python package. "
-            "Please install it with `pip install google-api-core`."
-        )
-
-    multiplier = 2
-    min_seconds = 1
-    max_seconds = 60
-    max_retries = 10
-
-    return retry(
-        reraise=True,
-        stop=stop_after_attempt(max_retries),
-        wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds),
-        retry=(
-            retry_if_exception_type(google.api_core.exceptions.ResourceExhausted)
-            | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)
-            | retry_if_exception_type(google.api_core.exceptions.GoogleAPIError)
-        ),
-        before_sleep=before_sleep_log(logger, logging.WARNING),
-    )
-
-
-def generate_with_retry(llm: GooglePalm, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: GooglePalm,
+    *args: Any,
+    run_manager: Optional[CallbackManagerForLLMRun] = None,
+    **kwargs: Any,
+) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator()
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
     @retry_decorator
-    def _generate_with_retry(**kwargs: Any) -> Any:
-        return llm.client.generate_text(**kwargs)
+    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
+        return llm.client.generate_text(*args, **kwargs)
 
-    return _generate_with_retry(**kwargs)
+    return _completion_with_retry(*args, **kwargs)
 
 
 def _strip_erroneous_leading_spaces(text: str) -> str:
@@ -94,6 +64,8 @@ class GooglePalm(BaseLLM, BaseModel):
     n: int = 1
     """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -144,7 +116,7 @@
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
-            completion = generate_with_retry(
+            completion = completion_with_retry(
                 self,
                 model=self.model_name,
                 prompt=prompt,
@@ -170,3 +142,17 @@
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "google_palm"
+
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            text: The string input to tokenize.
+
+        Returns:
+            The integer number of tokens in the text.
+        """
+        result = self.client.count_text_tokens(model=self.model_name, prompt=text)
+        return result["token_count"]
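For context on the get_num_tokens method added above: it simply forwards to the PaLM client's count_text_tokens call. A standalone sketch of that underlying call, assuming the google-generativeai package is installed; the API key shown is a placeholder, not part of the commit.

import google.generativeai as palm

palm.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key

# Same call that GooglePalm.get_num_tokens makes via self.client; returns a dict
# containing the token count for the given prompt.
result = palm.count_text_tokens(model="models/text-bison-001", prompt="How are you?")
print(result["token_count"])  # the integration test below expects 4 for this input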
@@ -4,13 +4,11 @@ from concurrent.futures import Executor, ThreadPoolExecutor
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ClassVar,
     Dict,
     Iterator,
     List,
     Optional,
-    Union,
 )
 
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -20,8 +18,9 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
+from langchain.llms.base import BaseLLM
 from langchain.utilities.vertexai import (
+    create_retry_decorator,
     get_client_info,
     init_vertexai,
     raise_vertex_import_error,
@@ -65,27 +64,6 @@ def is_codey_model(model_name: str) -> bool:
     return "code" in model_name
 
 
-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
 def completion_with_retry(
     llm: VertexAI,
     *args: Any,
@@ -93,7 +71,7 @@ def completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -109,7 +87,9 @@ def stream_completion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(
+        llm, max_retries=llm.max_retries, run_manager=run_manager
+    )
 
     @retry_decorator
     def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -125,7 +105,7 @@ async def acompletion_with_retry(
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
+    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
     async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -1,12 +1,43 @@
 """Utilities to init Vertex AI."""
 from importlib import metadata
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 
 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials
 
 
+def create_retry_decorator(
+    llm: BaseLLM,
+    *,
+    max_retries: int = 1,
+    run_manager: Optional[
+        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
+    ] = None,
+) -> Callable[[Any], Any]:
+    """Creates a retry decorator for Vertex / Palm LLMs."""
+    import google.api_core
+
+    errors = [
+        google.api_core.exceptions.ResourceExhausted,
+        google.api_core.exceptions.ServiceUnavailable,
+        google.api_core.exceptions.Aborted,
+        google.api_core.exceptions.DeadlineExceeded,
+        google.api_core.exceptions.GoogleAPIError,
+    ]
+    decorator = create_base_retry_decorator(
+        error_types=errors, max_retries=max_retries, run_manager=run_manager
+    )
+    return decorator
+
+
 def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
 
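The create_retry_decorator helper added above replaces the hand-rolled tenacity setup deleted from the GooglePalm module. For reference, that removed setup is roughly what create_base_retry_decorator builds internally; a minimal sketch under that assumption (the helper name and the backoff constants are the old, now-removed values and are only illustrative — the shared helper may use different ones).

import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def make_palm_retry(max_retries: int = 6):  # hypothetical helper name
    import google.api_core.exceptions as exc

    # Exponential backoff between 1 and 60 seconds, retrying only on the Google
    # API errors listed in the diff, and re-raising the last error at the end.
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=2, min=1, max=60),
        retry=(
            retry_if_exception_type(exc.ResourceExhausted)
            | retry_if_exception_type(exc.ServiceUnavailable)
            | retry_if_exception_type(exc.GoogleAPIError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )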
@@ -6,6 +6,8 @@ Note: This test must be run with the GOOGLE_API_KEY environment variable set to
 
 from pathlib import Path
 
+from langchain_core.outputs import LLMResult
+
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.loading import load_llm
 
@@ -15,6 +17,22 @@ def test_google_palm_call() -> None:
     llm = GooglePalm(max_output_tokens=10)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert llm._llm_type == "google_palm"
+    assert llm.model_name == "models/text-bison-001"
+
+
+def test_google_palm_generate() -> None:
+    llm = GooglePalm(temperature=0.3, n=2)
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
+
+
+def test_google_palm_get_num_tokens() -> None:
+    llm = GooglePalm()
+    output = llm.get_num_tokens("How are you?")
+    assert output == 4
 
 
 def test_saving_loading_llm(tmp_path: Path) -> None: