BUGFIX: add support for various OSS images from Vertex Model Garden (#13917)
- **Description:** add support for various OSS images from Model Garden
- **Issue:** #13370
This commit is contained in:
parent e186637921
commit 25387db432
@@ -32,6 +32,8 @@ if TYPE_CHECKING:
         PredictionServiceAsyncClient,
         PredictionServiceClient,
     )
+    from google.cloud.aiplatform.models import Prediction
+    from google.protobuf.struct_pb2 import Value
     from vertexai.language_models._language_models import (
         TextGenerationResponse,
         _LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     endpoint_id: str
     "A name of an endpoint where the model has been deployed."
     allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
+    "Allowed optional args to be passed to the model."
     prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
+    result_arg: Optional[str] = "generated_text"
+    "Set result_arg to None if output of the model is expected to be a string."
+    "Otherwise, if it's a dict, provided an argument that contains the result."
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
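For orientation, a minimal usage sketch of the new `result_arg` field (the endpoint IDs, project, and location below are placeholders, not values from this commit): keep the default `"generated_text"` when the deployed model returns a dict per prediction, and pass `result_arg=None` when it returns a bare string.

```python
from langchain.llms import VertexAIModelGarden

# Placeholder endpoint IDs, project, and location; substitute real deployments.
# Deployment whose predictions are dicts such as {"generated_text": "..."}:
falcon_llm = VertexAIModelGarden(
    endpoint_id="1234567890",
    project="my-gcp-project",
    location="europe-west4",
    result_arg="generated_text",  # pull the text out of each prediction dict
)

# Deployment whose predictions are already plain strings:
llama_llm = VertexAIModelGarden(
    endpoint_id="0987654321",
    project="my-gcp-project",
    location="europe-west4",
    result_arg=None,  # predictions are returned as-is
)

print(falcon_llm("What is the meaning of life?"))
print(llama_llm("What is the meaning of life?"))
```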
@@ -386,7 +390,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         except ImportError:
             raise_vertex_import_error()
 
-        if values["project"] is None:
+        if not values["project"]:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
@@ -401,12 +405,42 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
+        values["endpoint_path"] = values["client"].endpoint_path(
+            project=values["project"],
+            location=values["location"],
+            endpoint=values["endpoint_id"],
+        )
         return values
 
     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"
 
+    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with"
+                " `pip install protobuf`"
+            )
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
+
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+        return predict_instances
+
     def _generate(
         self,
         prompts: List[str],
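A standalone sketch of what the new `_prepare_request` helper does (the kwargs and `allowed_model_args` values below are illustrative): whitelisted kwargs are merged with each prompt and the resulting dicts are converted to protobuf `Value` payloads with `json_format.ParseDict`.

```python
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Illustrative inputs; the kwargs and allowed_model_args values are made up.
prompts = ["What is the meaning of life?"]
allowed_model_args = ["temperature", "max_tokens"]
kwargs = {"temperature": 0.2, "max_tokens": 64, "unrelated_arg": True}

instances = []
for prompt in prompts:
    # Keep only whitelisted kwargs, then attach the prompt under prompt_arg ("prompt").
    instance = {k: v for k, v in kwargs.items() if k in allowed_model_args}
    instance["prompt"] = prompt
    instances.append(instance)

# Each dict becomes a protobuf Value, the payload type that
# PredictionServiceClient.predict(endpoint=..., instances=...) expects.
predict_instances = [json_format.ParseDict(d, Value()) for d in instances]
print(predict_instances[0])
```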
@@ -415,41 +449,43 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+        instances = self._prepare_request(prompts, **kwargs)
+        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+        return self._parse_response(response)
+
+    def _parse_response(self, predictions: "Prediction") -> LLMResult:
         generations: List[List[Generation]] = []
-        for result in response.predictions:
+        for result in predictions.predictions:
             generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
+                [
+                    Generation(text=self._parse_prediction(prediction))
+                    for prediction in result
+                ]
             )
         return LLMResult(generations=generations)
 
+    def _parse_prediction(self, prediction: Any) -> str:
+        if isinstance(prediction, str):
+            return prediction
+
+        if self.result_arg:
+            try:
+                return prediction[self.result_arg]
+            except KeyError:
+                if isinstance(prediction, str):
+                    error_desc = (
+                        "Provided non-None `result_arg` (result_arg="
+                        f"{self.result_arg}). But got prediction of type "
+                        f"{type(prediction)} instead of dict. Most probably, you"
+                        "need to set `result_arg=None` during VertexAIModelGarden "
+                        "initialization."
+                    )
+                    raise ValueError(error_desc)
+                else:
+                    raise ValueError(f"{self.result_arg} key not found in prediction!")
+
+        return prediction
+
     async def _agenerate(
         self,
         prompts: List[str],
@@ -458,39 +494,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
+        instances = self._prepare_request(prompts, **kwargs)
         response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
+            endpoint=self.endpoint_path, instances=instances
         )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
-            )
-        return LLMResult(generations=generations)
+        return self._parse_response(response)
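A simplified sketch of the parsing rule that the new `_parse_response`/`_parse_prediction` pair implements (the sample payloads are made up): string predictions pass through unchanged, dict predictions are indexed by `result_arg`, and a missing key raises a `ValueError`.

```python
from typing import Any, Optional


def parse_prediction(prediction: Any, result_arg: Optional[str]) -> str:
    # Strings pass through untouched (the result_arg=None case).
    if isinstance(prediction, str):
        return prediction
    # Dict-like predictions are indexed by result_arg when it is set.
    if result_arg:
        try:
            return prediction[result_arg]
        except KeyError:
            raise ValueError(f"{result_arg} key not found in prediction!")
    return prediction


print(parse_prediction({"generated_text": "42"}, result_arg="generated_text"))  # -> 42
print(parse_prediction("42", result_arg=None))  # -> 42
```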
@@ -8,6 +8,7 @@ Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
 import os
+from typing import Optional
 
 import pytest
 from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text
 
 
-def test_model_garden() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm("What is the meaning of life?")
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai_model_garden"
 
 
-def test_model_garden_generate() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden_generate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.
 
     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
 
 
-async def test_model_garden_agenerate() -> None:
-    endpoint_id = os.environ["ENDPOINT_ID"]
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+async def test_model_garden_agenerate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
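The updated tests read their endpoints from `FALCON_ENDPOINT_ID` and `LLAMA_ENDPOINT_ID` rather than a single `ENDPOINT_ID`. A sketch of running just these integration tests once those variables (and `PROJECT`) are exported; the test file path is an assumption, and the async case additionally needs pytest-asyncio installed.

```python
import pytest

# Assumed path to the integration test module within the repo checkout.
pytest.main(["-k", "model_garden", "tests/integration_tests/llms/test_vertexai.py"])
```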