Adding support for async (_acall) for VertexAICommon LLM (#5588)

# Adding support for async (_acall) for VertexAICommon LLM

This PR implements the `_acall` method under `_VertexAICommon`. Because
VertexAI itself does not provide an async interface, I implemented it
with a shared `ThreadPoolExecutor`: blocking VertexAI calls are submitted
to worker threads, and the resulting futures are wrapped so they can be
awaited without blocking the event loop.
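
The core mechanism is the standard `asyncio.wrap_future(executor.submit(...))` bridge. Below is a minimal, self-contained sketch of that pattern (the `predict` function is a stand-in for the blocking VertexAI SDK call, not code from this PR):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


def predict(prompt: str) -> str:
    # Stand-in for the blocking VertexAI SDK call.
    return f"echo: {prompt}"


# One shared pool, sized to the allowed request parallelism.
executor = ThreadPoolExecutor(max_workers=5)


async def apredict(prompt: str) -> str:
    # submit() returns a concurrent.futures.Future; wrap_future turns it
    # into an awaitable asyncio future, so the event loop stays free while
    # the worker thread waits on the network call.
    return await asyncio.wrap_future(executor.submit(predict, prompt))


print(asyncio.run(apredict("hello")))
```

In the PR itself the executor is stored as a `ClassVar`, so all instances share one lazily created pool capped at `request_parallelism` workers.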

Twitter handle: @polecitoem : )


## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

fyi - @agola11 for async functionality
fyi - @Ark-kun from VertexAI

@@ -1,9 +1,14 @@
 """Wrapper around Google VertexAI models."""
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+import asyncio
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
@@ -43,6 +48,10 @@ class _VertexAICommon(BaseModel):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    task_executor: ClassVar[Optional[Executor]] = None
 
     @property
     def is_codey_model(self) -> bool:
@@ -81,6 +90,12 @@ class _VertexAICommon(BaseModel):
     def _llm_type(self) -> str:
         return "vertexai"
 
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
@@ -121,6 +136,26 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._predict, prompt, stop)
+        )
+
     def _call(
         self,
         prompt: str,
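
For completeness, a hypothetical usage sketch of the new async path (the prompt is illustrative, and a GCP project plus credentials are assumed to be configured in the environment; `agenerate` is the public async entry point on LangChain's base LLM class, which dispatches to `_acall`):

```python
import asyncio

from langchain.llms import VertexAI


async def main() -> None:
    llm = VertexAI()  # assumes project/credentials come from the environment
    # agenerate() runs _acall under the hood: the blocking SDK call executes
    # on the shared thread pool while the event loop stays responsive.
    result = await llm.agenerate(["Tell me a joke about threads."])
    print(result.generations[0][0].text)


asyncio.run(main())
```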