From 6370808d41eda0d056375015fda9284e9f01280c Mon Sep 17 00:00:00 2001
From: Pablo
Date: Wed, 28 Jun 2023 23:07:41 -0700
Subject: [PATCH] Adding support for async (_acall) for VertexAICommon LLM
 (#5588)

# Adding support for async (_acall) for VertexAICommon LLM

This PR implements the `_acall` method under `_VertexAICommon`. Because
VertexAI itself does not provide an async interface, I implemented it via a
`ThreadPoolExecutor` that delegates execution of the blocking VertexAI calls
to worker threads.

Twitter handle: @polecitoem : )

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

fyi - @agola11 for async functionality
fyi - @Ark-kun from VertexAI
---
 langchain/llms/vertexai.py | 39 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py
index 0f85920cfa1..22f5cc42981 100644
--- a/langchain/llms/vertexai.py
+++ b/langchain/llms/vertexai.py
@@ -1,9 +1,14 @@
 """Wrapper around Google VertexAI models."""
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+import asyncio
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
@@ -43,6 +48,10 @@ class _VertexAICommon(BaseModel):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    task_executor: ClassVar[Optional[Executor]] = None
 
     @property
     def is_codey_model(self) -> bool:
@@ -81,6 +90,12 @@ class _VertexAICommon(BaseModel):
     def _llm_type(self) -> str:
         return "vertexai"
 
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
@@ -121,6 +136,26 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._predict, prompt, stop)
+        )
+
     def _call(
         self,
         prompt: str,
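
For reviewers who want to see the delegation pattern in isolation, here is a minimal, self-contained sketch of the approach described above. `blocking_predict` and the module-level `executor` are hypothetical stand-ins for illustration, not part of this change.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Hypothetical blocking call, standing in for VertexAI's synchronous predict().
def blocking_predict(prompt: str) -> str:
    return f"echo: {prompt}"

# Mirrors the PR's shared executor with request_parallelism=5.
executor = ThreadPoolExecutor(max_workers=5)

async def apredict(prompt: str) -> str:
    # submit() runs the blocking call on a worker thread and returns a
    # concurrent.futures.Future; asyncio.wrap_future bridges that Future
    # into an awaitable, which is exactly what _acall does in the diff above.
    return await asyncio.wrap_future(executor.submit(blocking_predict, prompt))

print(asyncio.run(apredict("hello")))
```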
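And a sketch of how the new async path might be exercised end to end, assuming Vertex AI project and credentials are already picked up from the environment; production code would normally go through the LLM's public async methods rather than calling `_acall` directly.

```python
import asyncio

from langchain.llms import VertexAI

async def main() -> None:
    llm = VertexAI(request_parallelism=5)
    # Both requests are submitted to the shared ThreadPoolExecutor, so the
    # blocking _predict calls run concurrently on worker threads.
    results = await asyncio.gather(
        llm._acall("What is Vertex AI?"),
        llm._acall("Name one Google Cloud region."),
    )
    print(results)

asyncio.run(main())
```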