# Adding support for async (_acall) for VertexAICommon LLM (#5588)

This PR implements the `_acall` method under `_VertexAICommon`. Because VertexAI itself does not provide an async interface, it is implemented via a `ThreadPoolExecutor` that delegates execution of VertexAI calls to worker threads.

Twitter handle: @polecitoem :)

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

- fyi @agola11 for async functionality
- fyi @Ark-kun from VertexAI
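The change boils down to a standard asyncio bridging pattern: submit the blocking SDK call to a `ThreadPoolExecutor` and `await` the wrapped future. Below is a minimal, self-contained sketch of that pattern; `blocking_predict` is a hypothetical stand-in for the real VertexAI prediction call.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# One shared pool, analogous to the ClassVar executor in the PR.
executor = ThreadPoolExecutor(max_workers=5)


def blocking_predict(prompt: str) -> str:
    # Hypothetical stand-in for a synchronous VertexAI prediction call.
    return f"echo: {prompt}"


async def apredict(prompt: str) -> str:
    # submit() returns a concurrent.futures.Future; asyncio.wrap_future
    # converts it into an awaitable, so the event loop stays responsive
    # while the blocking call runs on a worker thread.
    return await asyncio.wrap_future(executor.submit(blocking_predict, prompt))


print(asyncio.run(apredict("hello")))
```

An alternative would be `loop.run_in_executor`, which returns an awaitable directly; `asyncio.wrap_future` over `executor.submit` achieves the same effect here.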
This commit is contained in:

- parent `cbd759aaeb`
- commit `6370808d41`
```diff
@@ -1,9 +1,14 @@
 """Wrapper around Google VertexAI models."""
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+import asyncio
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional
 
 from pydantic import BaseModel, root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utilities.vertexai import (
@@ -43,6 +48,10 @@ class _VertexAICommon(BaseModel):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    task_executor: ClassVar[Optional[Executor]] = None
 
     @property
     def is_codey_model(self) -> bool:
@@ -81,6 +90,12 @@ class _VertexAICommon(BaseModel):
     def _llm_type(self) -> str:
         return "vertexai"
 
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
@@ -121,6 +136,26 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values
 
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._predict, prompt, stop)
+        )
+
     def _call(
         self,
         prompt: str,
```