From 25387db4328c3531231561debf293c59eb458c7a Mon Sep 17 00:00:00 2001
From: Leonid Kuligin
Date: Mon, 27 Nov 2023 19:31:53 +0100
Subject: [PATCH] BUGFIX: add support for various OSS images from Vertex Model
 Garden (#13917)

- **Description:** add support for various OSS images from Model Garden
- **Issue:** #13370
---
 libs/langchain/langchain/llms/vertexai.py     | 137 +++++++++---
 .../integration_tests/llms/test_vertexai.py   |  66 +++++++--
 2 files changed, 124 insertions(+), 79 deletions(-)

diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index 4a3c67b7818..994cfdf6bcf 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -32,6 +32,8 @@ if TYPE_CHECKING:
         PredictionServiceAsyncClient,
         PredictionServiceClient,
     )
+    from google.cloud.aiplatform.models import Prediction
+    from google.protobuf.struct_pb2 import Value
     from vertexai.language_models._language_models import (
         TextGenerationResponse,
         _LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     endpoint_id: str
     "A name of an endpoint where the model has been deployed."
     allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
+    "Allowed optional args to be passed to the model."
     prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
+    result_arg: Optional[str] = "generated_text"
+    "Set result_arg to None if the output of the model is expected to be a string."
+    "Otherwise, if it's a dict, provide the key that contains the result."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -386,7 +390,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         except ImportError:
             raise_vertex_import_error()
 
-        if values["project"] is None:
+        if not values["project"]:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
@@ -401,12 +405,42 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
+        values["endpoint_path"] = values["client"].endpoint_path(
+            project=values["project"],
+            location=values["location"],
+            endpoint=values["endpoint_id"],
+        )
         return values
 
     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"
 
+    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with"
+                " `pip install protobuf`"
+            )
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
+
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+        return predict_instances
+
     def _generate(
         self,
         prompts: List[str],
@@ -415,41 +449,43 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
+        instances = self._prepare_request(prompts, **kwargs)
+        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+        return self._parse_response(response)
 
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+    def _parse_response(self, predictions: "Prediction") -> LLMResult:
         generations: List[List[Generation]] = []
-        for result in response.predictions:
+        for result in predictions.predictions:
             generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
+                [
+                    Generation(text=self._parse_prediction(prediction))
+                    for prediction in result
+                ]
             )
         return LLMResult(generations=generations)
 
+    def _parse_prediction(self, prediction: Any) -> str:
+        if isinstance(prediction, str):
+            return prediction
+
+        if self.result_arg:
+            try:
+                return prediction[self.result_arg]
+            except KeyError:
+                if isinstance(prediction, str):
+                    error_desc = (
+                        "Provided non-None `result_arg` (result_arg="
+                        f"{self.result_arg}). But got prediction of type "
+                        f"{type(prediction)} instead of dict. Most probably, you "
+                        "need to set `result_arg=None` during VertexAIModelGarden "
+                        "initialization."
+                    )
+                    raise ValueError(error_desc)
+                else:
+                    raise ValueError(f"{self.result_arg} key not found in prediction!")
+
+        return prediction
+
     async def _agenerate(
         self,
         prompts: List[str],
@@ -458,39 +494,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
+        instances = self._prepare_request(prompts, **kwargs)
         response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
+            endpoint=self.endpoint_path, instances=instances
         )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
-            )
-        return LLMResult(generations=generations)
+        return self._parse_response(response)
diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
index ef9c8fb1b53..6ddb7044874 100644
--- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py
+++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py
@@ -8,6 +8,7 @@ Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
 import os
+from typing import Optional
 
 import pytest
 from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text
 
 
-def test_model_garden() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm("What is the meaning of life?")
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai_model_garden"
 
 
-def test_model_garden_generate() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden_generate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
 
 
-async def test_model_garden_agenerate() -> None:
-    endpoint_id = os.environ["ENDPOINT_ID"]
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+async def test_model_garden_agenerate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
    assert len(output.generations) == 2
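
For reference, a minimal usage sketch of the new `result_arg` behaviour, mirroring the parametrized integration tests above. The project, location, and endpoint IDs below are placeholders for endpoints you have deployed from Model Garden yourself, not real values:

```python
# Minimal usage sketch of the `result_arg` handling added in this patch.
# The project, location, and endpoint IDs are placeholders (assumptions),
# not real deployments.
from langchain.llms.vertexai import VertexAIModelGarden

# Endpoint whose predictions are dicts such as {"generated_text": "..."}:
# point `result_arg` at the key that holds the generated text.
falcon_llm = VertexAIModelGarden(
    project="my-project",          # placeholder GCP project
    endpoint_id="1234567890",      # placeholder Falcon endpoint ID
    location="europe-west4",
    result_arg="generated_text",
)

# Endpoint whose predictions are already plain strings (e.g. some Llama
# deployments): disable the key lookup by setting `result_arg=None`.
llama_llm = VertexAIModelGarden(
    project="my-project",
    endpoint_id="0987654321",      # placeholder Llama endpoint ID
    location="europe-west4",
    result_arg=None,
)

print(falcon_llm("What is the meaning of life?"))
result = llama_llm.generate(["What is the meaning of life?"])
print(result.generations[0][0].text)
```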