BUGFIX: add support for various OSS images from Vertex Model Garden (#13917)

- **Description:** add support for various OSS images from Model Garden
  - **Issue:** #13370
Leonid Kuligin 2023-11-27 19:31:53 +01:00 committed by GitHub
parent e186637921
commit 25387db432
2 changed files with 124 additions and 79 deletions
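
This change makes `result_arg` optional so that Model Garden endpoints returning plain strings (e.g. some LLaMA deployments) work alongside endpoints returning dicts (e.g. Falcon's `generated_text`). A minimal usage sketch, assuming the langchain import path of this release; the endpoint IDs and project below are placeholders:

    from langchain.llms import VertexAIModelGarden

    # Falcon-style endpoint: each prediction is a dict keyed by "generated_text".
    falcon_llm = VertexAIModelGarden(
        endpoint_id="1234567890",   # placeholder endpoint ID
        project="my-gcp-project",   # placeholder project
        result_arg="generated_text",
    )

    # LLaMA-style endpoint: each prediction is already a plain string.
    llama_llm = VertexAIModelGarden(
        endpoint_id="0987654321",   # placeholder endpoint ID
        project="my-gcp-project",
        result_arg=None,
    )

    print(falcon_llm("What is the meaning of life?"))
    print(llama_llm("What is the meaning of life?"))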

View File

@@ -32,6 +32,8 @@ if TYPE_CHECKING:
         PredictionServiceAsyncClient,
         PredictionServiceClient,
     )
+    from google.cloud.aiplatform.models import Prediction
+    from google.protobuf.struct_pb2 import Value
     from vertexai.language_models._language_models import (
         TextGenerationResponse,
         _LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     endpoint_id: str
     "A name of an endpoint where the model has been deployed."
     allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
+    "Allowed optional args to be passed to the model."
     prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
+    result_arg: Optional[str] = "generated_text"
+    "Set result_arg to None if the output of the model is expected to be a string."
+    "Otherwise, if it's a dict, provide the key that contains the result."

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
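
The two `result_arg` modes correspond to the two prediction shapes a Model Garden endpoint may return; a sketch of the payloads being distinguished (values are illustrative):

    # result_arg="generated_text" (Falcon-style): predictions are dicts.
    prediction = {"generated_text": "42"}

    # result_arg=None (LLaMA-style): predictions are already strings.
    prediction = "42"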
@@ -386,7 +390,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         except ImportError:
             raise_vertex_import_error()
-        if values["project"] is None:
+        if not values["project"]:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
@@ -401,12 +405,42 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
+        values["endpoint_path"] = values["client"].endpoint_path(
+            project=values["project"],
+            location=values["location"],
+            endpoint=values["endpoint_id"],
+        )
         return values

     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"

+    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with"
+                " `pip install protobuf`"
+            )
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+        return predict_instances
+
     def _generate(
         self,
         prompts: List[str],
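
`_prepare_request` centralizes the dict-to-protobuf conversion now shared by the sync and async paths. Roughly, for one prompt and a hypothetical whitelist `allowed_model_args=["temperature"]`, it does:

    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value

    # Only whitelisted kwargs are forwarded, then the prompt is attached.
    instance = {"temperature": 0.7, "prompt": "What is the meaning of life?"}
    predict_instance = json_format.ParseDict(instance, Value())  # ready for predict()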
@@ -415,41 +449,43 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+        instances = self._prepare_request(prompts, **kwargs)
+        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+        return self._parse_response(response)
+
+    def _parse_response(self, predictions: "Prediction") -> LLMResult:
         generations: List[List[Generation]] = []
-        for result in response.predictions:
+        for result in predictions.predictions:
             generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
+                [
+                    Generation(text=self._parse_prediction(prediction))
+                    for prediction in result
+                ]
             )
         return LLMResult(generations=generations)
+
+    def _parse_prediction(self, prediction: Any) -> str:
+        if isinstance(prediction, str):
+            return prediction
+        if self.result_arg:
+            try:
+                return prediction[self.result_arg]
+            except KeyError:
+                if isinstance(prediction, str):
+                    error_desc = (
+                        "Provided non-None `result_arg` (result_arg="
+                        f"{self.result_arg}). But got prediction of type "
+                        f"{type(prediction)} instead of dict. Most probably, you "
+                        "need to set `result_arg=None` during VertexAIModelGarden "
+                        "initialization."
+                    )
+                    raise ValueError(error_desc)
+                else:
+                    raise ValueError(f"{self.result_arg} key not found in prediction!")
+        return prediction

     async def _agenerate(
         self,
         prompts: List[str],
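
The parsing logic added here reduces to roughly the following standalone sketch (a hypothetical helper, not part of the diff):

    from typing import Any, Optional

    def parse_prediction(prediction: Any, result_arg: Optional[str]) -> str:
        # Plain-string predictions (e.g. LLaMA) pass through untouched.
        if isinstance(prediction, str):
            return prediction
        # Dict predictions (e.g. Falcon) are indexed by result_arg when it is set;
        # in the class, a missing key is reported as a ValueError.
        if result_arg:
            return prediction[result_arg]
        return prediction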
@@ -458,39 +494,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
+        instances = self._prepare_request(prompts, **kwargs)
         response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
+            endpoint=self.endpoint_path, instances=instances
         )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
-            )
-        return LLMResult(generations=generations)
+        return self._parse_response(response)
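
With the helpers shared, the async path now differs from the sync one only in awaiting the client call. A usage sketch, with `llama_llm` constructed as in the earlier sketch:

    import asyncio

    async def main() -> None:
        result = await llama_llm.agenerate(["What is the meaning of life?"])
        print(result.generations[0][0].text)

    asyncio.run(main())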

View File

@@ -8,6 +8,7 @@ Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
 import os
+from typing import Optional

 import pytest
 from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text


-def test_model_garden() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm("What is the meaning of life?")
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai_model_garden"


-def test_model_garden_generate() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden_generate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2


-async def test_model_garden_agenerate() -> None:
-    endpoint_id = os.environ["ENDPOINT_ID"]
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+async def test_model_garden_agenerate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
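
For reference, one way to run just these integration tests locally (the endpoint IDs, project, and file path below are placeholders/assumptions; real deployed endpoints and GCP credentials are required):

    import os
    import pytest

    os.environ["FALCON_ENDPOINT_ID"] = "1234567890"  # placeholder
    os.environ["LLAMA_ENDPOINT_ID"] = "0987654321"   # placeholder
    os.environ["PROJECT"] = "my-gcp-project"         # placeholder

    # File path is an assumption; adjust to the repo layout.
    pytest.main(["-k", "model_garden", "tests/integration_tests/llms/test_vertexai.py"])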