add vertex prod features (#10910)

- ChatVertexAI: async generation
- VertexAI: streaming
- VertexAI: full generation info (is_blocked, safety attributes)
- VertexAI: server-side stopping (stop sequences passed to the API)
- Model Garden: async generation
- docs updated for all of the above (usage sketch below)

Follow-ups:
- [ ] ChatVertexAI: full generation info
- [ ] ChatVertexAI: retries
- [ ] scheduled tests
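For context, a rough usage sketch of the new surface (illustrative only, not part of the diff):

```python
import asyncio

from langchain.llms import VertexAI

llm = VertexAI(temperature=0)

# Streaming, with stop sequences handled server-side by the Vertex API.
for chunk in llm.stream("Say foo:", stop=["\n"]):
    print(chunk, end="")

# Native async generation (no more thread-pool wrapping of the sync call).
result = asyncio.run(llm.agenerate(["Say foo:"]))
print(result.generations[0][0].text)
```

Chat async and Model Garden async follow the same `agenerate` pattern; sketches are included next to the relevant diffs below.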
Commit: cab55e9bc1 (parent dccc20b402)
Author: Bagatur
Date: 2023-09-22 01:44:09 -07:00
Committed by: GitHub
10 changed files with 721 additions and 267 deletions

View File

@@ -1,10 +1,14 @@
"""Wrapper around Google VertexAI chat-based models."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.pydantic_v1 import root_validator
@@ -30,6 +34,8 @@ if TYPE_CHECKING:
InputOutputTextPair,
)
logger = logging.getLogger(__name__)
@dataclass
class _ChatHistory:
@@ -116,7 +122,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"""`Vertex AI` Chat large language models API."""
model_name: str = "chat-bison"
streaming: bool = False
"Underlying model name."
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -177,6 +183,42 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
does not support context.
stop: The list of stop words (optional).
run_manager: The CallbackManager for the LLM run; not used at the moment.
Returns:
The ChatResult that contains outputs generated by the model.
Raises:
ValueError: if the last message in the list is not from a human.
"""
if "stream" in kwargs:
kwargs.pop("stream")
logger.warning("ChatVertexAI does not currently support async streaming.")
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = {**self._default_params, **kwargs}
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)
chat = self._start_chat(history, params)
response = await chat.send_message_async(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
def _stream(
self,
messages: List[BaseMessage],

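For reference, a minimal sketch (not part of the diff) of how the new async chat path above is reached through the public API:

```python
import asyncio

from langchain.chat_models import ChatVertexAI
from langchain.schema.messages import HumanMessage


async def main() -> None:
    chat = ChatVertexAI(temperature=0)
    # Dispatches to the _agenerate added above; async streaming is not supported yet.
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)


asyncio.run(main())
```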
View File

@@ -1,28 +1,58 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Union,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
Generation,
LLMResult,
)
from langchain.schema.output import GenerationChunk
from langchain.utilities.vertexai import (
init_vertexai,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from google.cloud.aiplatform.gapic import PredictionServiceClient
from vertexai.language_models._language_models import _LanguageModel
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
from vertexai.language_models._language_models import (
TextGenerationResponse,
_LanguageModel,
)
def _response_to_generation(
response: TextGenerationResponse,
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
try:
generation_info = {
"is_blocked": response.is_blocked,
"safety_attributes": response.safety_attributes,
}
except Exception:
generation_info = None
return GenerationChunk(text=response.text, generation_info=generation_info)
def is_codey_model(model_name: str) -> bool:
@@ -36,7 +66,13 @@ def is_codey_model(model_name: str) -> bool:
return "code" in model_name
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
def _create_retry_decorator(
llm: VertexAI,
*,
run_manager: Optional[
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
import google.api_core
errors = [
@@ -46,14 +82,19 @@ def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
google.api_core.exceptions.DeadlineExceeded,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries # type: ignore
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
return decorator
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
def completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
@@ -62,6 +103,38 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
return _completion_with_retry(*args, **kwargs)
def stream_completion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
return llm.client.predict_streaming(*args, **kwargs)
return _completion_with_retry(*args, **kwargs)
async def acompletion_with_retry(
llm: VertexAI,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@retry_decorator
async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
return await llm.client.predict_async(*args, **kwargs)
return await _acompletion_with_retry(*args, **kwargs)
class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
@@ -110,6 +183,11 @@ class _VertexAICommon(_VertexAIBase):
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
streaming: bool = False
@property
def _llm_type(self) -> str:
return "vertexai"
@property
def is_codey_model(self) -> bool:
@@ -135,17 +213,6 @@ class _VertexAICommon(_VertexAIBase):
"top_p": self.top_p,
}
def _predict(
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = completion_with_retry(self, prompt, **params) # type: ignore
return self._enforce_stop_words(res.text, stop)
@property
def _llm_type(self) -> str:
return "vertexai"
@classmethod
def _try_init_vertexai(cls, values: Dict) -> None:
allowed_params = ["project", "location", "credentials"]
@@ -154,13 +221,14 @@ class _VertexAICommon(_VertexAIBase):
return None
class VertexAI(_VertexAICommon, LLM):
class VertexAI(_VertexAICommon, BaseLLM):
"""Google Vertex AI large language models."""
model_name: str = "text-bison"
"The name of the Vertex AI large language model."
tuned_model_name: Optional[str] = None
"The name of a tuned model. If provided, model_name is ignored."
streaming: bool = False
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -191,51 +259,78 @@ class VertexAI(_VertexAICommon, LLM):
raise_vertex_import_error()
return values
def _call(
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> LLMResult:
stop_sequences = stop or self.stop
should_stream = stream if stream is not None else self.streaming
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
generations = []
for prompt in prompts:
if should_stream:
generation = GenerationChunk(text="")
for chunk in self._stream(
prompt, stop=stop, run_manager=run_manager, **kwargs
):
generation += chunk
generations.append([generation])
else:
res = completion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(res)])
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
generations = []
for prompt in prompts:
res = await acompletion_with_retry(
self, prompt, run_manager=run_manager, **params
)
generations.append([_response_to_generation(res)])
return LLMResult(generations=generations)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A Callbackmanager for LLM run, optional.
Returns:
The string generated by the model.
"""
return self._predict(prompt, stop, **kwargs)
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
)
) -> Iterator[GenerationChunk]:
stop_sequences = stop or self.stop
params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
for stream_resp in stream_completion_with_retry(
self, prompt, run_manager=run_manager, **params
):
chunk = _response_to_generation(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=self.verbose,
)
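A quick sketch (not part of the diff) of what the richer generation info looks like to callers of the non-streaming path:

```python
from langchain.llms import VertexAI

llm = VertexAI(temperature=0)

# Stop sequences are forwarded as "stop_sequences", so stopping happens server-side.
result = llm.generate(["Say foo:"], stop=["\n"])
generation = result.generations[0][0]
print(generation.text)
# Populated by _response_to_generation, e.g. {"is_blocked": False, "safety_attributes": {...}}
print(generation.generation_info)
```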
class VertexAIModelGarden(_VertexAIBase, LLM):
class VertexAIModelGarden(_VertexAIBase, BaseLLM):
"""Large language models served from Vertex AI Model Garden."""
client: "PredictionServiceClient" = None #: :meta private:
async_client: "PredictionServiceAsyncClient" = None #: :meta private:
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
@@ -247,7 +342,11 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
try:
from google.cloud.aiplatform.gapic import PredictionServiceClient
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform.gapic import (
PredictionServiceAsyncClient,
PredictionServiceClient,
)
except ImportError:
raise_vertex_import_error()
@@ -256,38 +355,19 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
"A GCP project should be provided to run inference on Model Garden!"
)
client_options = {
"api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
}
client_options = ClientOptions(
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
)
values["client"] = PredictionServiceClient(client_options=client_options)
values["async_client"] = PredictionServiceAsyncClient(
client_options=client_options
)
return values
@property
def _llm_type(self) -> str:
return "vertexai_model_garden"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A Callbackmanager for LLM run, optional.
Returns:
The string generated by the model.
"""
result = self._generate(
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
)
return result.generations[0][0].text
def _generate(
self,
prompts: List[str],
@@ -331,23 +411,47 @@ class VertexAIModelGarden(_VertexAIBase, LLM):
)
return LLMResult(generations=generations)
async def _acall(
async def _agenerate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
try:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
except ImportError:
raise ImportError(
"protobuf package not found, please install it with"
" `pip install protobuf`"
)
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
endpoint = self.async_client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
response = await self.async_client.predict(
endpoint=endpoint, instances=predict_instances
)
generations: List[List[Generation]] = []
for result in response.predictions:
generations.append(
[Generation(text=prediction[self.result_arg]) for prediction in result]
)
return LLMResult(generations=generations)
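And a sketch (not part of the diff) of the new async Model Garden path, which now goes through PredictionServiceAsyncClient; the endpoint and project values below are placeholders:

```python
import asyncio

from langchain.llms import VertexAIModelGarden


async def main() -> None:
    llm = VertexAIModelGarden(
        endpoint_id="YOUR_ENDPOINT_ID",  # placeholder
        project="YOUR_PROJECT",          # placeholder
    )
    result = await llm.agenerate(["What is the meaning of life?", "How much is 2+2?"])
    for generations in result.generations:
        print(generations[0].text)


asyncio.run(main())
```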

View File

@@ -13,6 +13,7 @@ import pytest
from langchain.chat_models import ChatVertexAI
from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples
from langchain.schema import LLMResult
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
@@ -26,10 +27,22 @@ def test_vertexai_single_call(model_name: str) -> None:
response = model([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert model._llm_type == "vertexai"
assert model._llm_type == "chat-vertexai"
assert model.model_name == model.client._model_id
@pytest.mark.asyncio
async def test_vertexai_agenerate() -> None:
model = ChatVertexAI(temperature=0)
message = HumanMessage(content="Hello")
response = await model.agenerate([[message]])
assert isinstance(response, LLMResult)
assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore
sync_response = model.generate([[message]])
assert response.generations[0][0] == sync_response.generations[0][0]
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()
raw_context = (

View File

@@ -14,7 +14,6 @@ def test_embedding_documents() -> None:
output = model.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id
@@ -40,5 +39,4 @@ def test_paginated_texts() -> None:
output = model.embed_documents(documents)
assert len(output) == 8
assert len(output[0]) == 768
assert model._llm_type == "vertexai"
assert model.model_name == model.client._model_id

View File

@@ -9,18 +9,49 @@ Your end-user credentials would be used to make the calls (make sure you've run
"""
import os
import pytest
from langchain.llms import VertexAI, VertexAIModelGarden
from langchain.schema import LLMResult
def test_vertex_call() -> None:
llm = VertexAI()
llm = VertexAI(temperature=0)
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "vertexai"
assert llm.model_name == llm.client._model_id
def test_vertex_generate() -> None:
llm = VertexAI(temperature=0)
output = llm.generate(["Please say foo:"])
assert isinstance(output, LLMResult)
@pytest.mark.asyncio
async def test_vertex_agenerate() -> None:
llm = VertexAI(temperature=0)
output = await llm.agenerate(["Please say foo:"])
assert isinstance(output, LLMResult)
def test_vertex_stream() -> None:
llm = VertexAI(temperature=0)
outputs = list(llm.stream("Please say foo:"))
assert isinstance(outputs[0], str)
@pytest.mark.asyncio
async def test_vertex_consistency() -> None:
llm = VertexAI(temperature=0)
output = llm.generate(["Please say foo:"])
streaming_output = llm.generate(["Please say foo:"], stream=True)
async_output = await llm.agenerate(["Please say foo:"])
assert output.generations[0][0].text == streaming_output.generations[0][0].text
assert output.generations[0][0].text == async_output.generations[0][0].text
def test_model_garden() -> None:
"""In order to run this test, you should provide an endpoint name.
@@ -37,7 +68,7 @@ def test_model_garden() -> None:
assert llm._llm_type == "vertexai_model_garden"
def test_model_garden_batch() -> None:
def test_model_garden_generate() -> None:
"""In order to run this test, you should provide an endpoint name.
Example:
@@ -47,6 +78,16 @@ def test_model_garden_batch() -> None:
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
@pytest.mark.asyncio
async def test_model_garden_agenerate() -> None:
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2