Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-08 20:41:52 +00:00)
Adding support for async (_acall) for VertexAICommon LLM (#5588)
# Adding support for async (_acall) for VertexAICommon LLM

This PR implements the `_acall` method under `_VertexAICommon`. Because VertexAI itself does not provide an async interface, I implemented it via a `ThreadPoolExecutor` that delegates execution of VertexAI calls to worker threads.

Twitter handle: @polecitoem : )

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

fyi - @agola11 for async functionality
fyi - @Ark-kun from VertexAI
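For readers unfamiliar with the approach, the core pattern is to submit the blocking SDK call to a `ThreadPoolExecutor` and convert the resulting `concurrent.futures.Future` into an awaitable with `asyncio.wrap_future`. Here is a minimal, self-contained sketch of that pattern — `blocking_predict` is a hypothetical stand-in for the VertexAI SDK call, not part of this PR:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for the blocking VertexAI prediction call.
def blocking_predict(prompt: str) -> str:
    return f"echo: {prompt}"

# A shared pool bounds how many blocking calls run at once.
executor = ThreadPoolExecutor(max_workers=5)

async def acall(prompt: str) -> str:
    # submit() returns a concurrent.futures.Future;
    # asyncio.wrap_future makes it awaitable on the running event loop.
    return await asyncio.wrap_future(executor.submit(blocking_predict, prompt))

async def main() -> None:
    # Three calls overlap on the pool instead of running one after another.
    print(await asyncio.gather(*(acall(f"prompt {i}") for i in range(3))))

asyncio.run(main())
```

The same effect can be achieved with `loop.run_in_executor(executor, fn, *args)`; `wrap_future` over an explicit `submit` simply keeps the executor choice in one place.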
parent: cbd759aaeb
commit: 6370808d41
@@ -1,9 +1,14 @@
 """Wrapper around Google VertexAI models."""
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+import asyncio
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
@@ -43,6 +48,10 @@ class _VertexAICommon(BaseModel):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    task_executor: ClassVar[Optional[Executor]] = None
 
     @property
     def is_codey_model(self) -> bool:
@@ -81,6 +90,12 @@ class _VertexAICommon(BaseModel):
     def _llm_type(self) -> str:
         return "vertexai"
 
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
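Note that `task_executor` is a `ClassVar`, so one pool is shared by every instance, and it is sized by whichever call creates it first; later `request_parallelism` values are ignored. A standalone sketch of that behaviour — the `Client` class is hypothetical, mirroring the pattern in the hunk above:

```python
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import ClassVar, Optional

class Client:
    # One pool shared by every instance of the class.
    task_executor: ClassVar[Optional[Executor]] = None

    @classmethod
    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
        # Lazily created on first use; reused on every later call.
        if cls.task_executor is None:
            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
        return cls.task_executor

a = Client._get_task_executor(request_parallelism=2)
b = Client._get_task_executor(request_parallelism=50)
assert a is b  # same pool; the first call's size (2 workers) won
```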
@@ -121,6 +136,26 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._predict, prompt, stop)
+        )
+
     def _call(
         self,
         prompt: str,
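Assuming langchain's standard async entry point (`agenerate` on the `LLM` base class dispatches to `_acall`), a hedged usage sketch follows. It needs `google-cloud-aiplatform` installed plus a GCP project with Vertex AI enabled and application-default credentials, so it is illustrative rather than runnable as-is:

```python
import asyncio

from langchain.llms import VertexAI  # requires GCP credentials at runtime

async def main() -> None:
    llm = VertexAI()
    # Two concurrent requests: with this PR each one ends up in _acall,
    # so the blocking SDK calls overlap on the shared thread pool.
    joke, fact = await asyncio.gather(
        llm.agenerate(["Tell me a joke."]),
        llm.agenerate(["Tell me a fact."]),
    )
    print(joke.generations[0][0].text)
    print(fact.generations[0][0].text)

asyncio.run(main())
```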