BUGFIX: add support for various OSS images from Vertex Model Garden (#13917)

- **Description:** add support for various OSS images from Vertex Model Garden
- **Issue:** #13370
parent e186637921
commit 25387db432
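In rough terms, this change lets `VertexAIModelGarden` work with Model Garden endpoints whose predictions are plain strings as well as those that return dicts. A minimal usage sketch of the new `result_arg` behaviour, assuming the usual `langchain.llms` import path; the endpoint IDs, project, and location below are placeholders, not values from this commit:

```python
from langchain.llms import VertexAIModelGarden

# Endpoint whose predictions are dicts, e.g. {"generated_text": "..."} (Falcon-style
# in the integration tests below):
falcon_llm = VertexAIModelGarden(
    endpoint_id="1234567890",      # hypothetical deployed-endpoint ID
    project="my-gcp-project",      # hypothetical GCP project
    location="europe-west4",
    result_arg="generated_text",   # pick this key out of each prediction dict
)

# Endpoint whose predictions are already plain strings (Llama-style in the tests):
llama_llm = VertexAIModelGarden(
    endpoint_id="0987654321",      # hypothetical deployed-endpoint ID
    project="my-gcp-project",
    location="europe-west4",
    result_arg=None,               # new: return the prediction string as-is
)

print(falcon_llm("What is the meaning of life?"))
print(llama_llm("What is the meaning of life?"))
```

The diff below implements this by parsing each prediction individually: strings pass through unchanged, while dicts are indexed by `result_arg`.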
@@ -32,6 +32,8 @@ if TYPE_CHECKING:
         PredictionServiceAsyncClient,
         PredictionServiceClient,
     )
+    from google.cloud.aiplatform.models import Prediction
+    from google.protobuf.struct_pb2 import Value
     from vertexai.language_models._language_models import (
         TextGenerationResponse,
         _LanguageModel,
@@ -370,9 +372,11 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
     endpoint_id: str
     "A name of an endpoint where the model has been deployed."
     allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
+    "Allowed optional args to be passed to the model."
     prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
+    result_arg: Optional[str] = "generated_text"
+    "Set result_arg to None if output of the model is expected to be a string."
+    "Otherwise, if it's a dict, provided an argument that contains the result."

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -386,7 +390,7 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         except ImportError:
             raise_vertex_import_error()

-        if values["project"] is None:
+        if not values["project"]:
             raise ValueError(
                 "A GCP project should be provided to run inference on Model Garden!"
             )
@@ -401,12 +405,42 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
+        values["endpoint_path"] = values["client"].endpoint_path(
+            project=values["project"],
+            location=values["location"],
+            endpoint=values["endpoint_id"],
+        )
         return values

     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"

+    def _prepare_request(self, prompts: List[str], **kwargs: Any) -> List["Value"]:
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with"
+                " `pip install protobuf`"
+            )
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
+
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+        return predict_instances
+
     def _generate(
         self,
         prompts: List[str],
@@ -415,41 +449,43 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+        instances = self._prepare_request(prompts, **kwargs)
+        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
+        return self._parse_response(response)
+
+    def _parse_response(self, predictions: "Prediction") -> LLMResult:
         generations: List[List[Generation]] = []
-        for result in response.predictions:
+        for result in predictions.predictions:
             generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
+                [
+                    Generation(text=self._parse_prediction(prediction))
+                    for prediction in result
+                ]
             )
         return LLMResult(generations=generations)

+    def _parse_prediction(self, prediction: Any) -> str:
+        if isinstance(prediction, str):
+            return prediction
+
+        if self.result_arg:
+            try:
+                return prediction[self.result_arg]
+            except KeyError:
+                if isinstance(prediction, str):
+                    error_desc = (
+                        "Provided non-None `result_arg` (result_arg="
+                        f"{self.result_arg}). But got prediction of type "
+                        f"{type(prediction)} instead of dict. Most probably, you"
+                        "need to set `result_arg=None` during VertexAIModelGarden "
+                        "initialization."
+                    )
+                    raise ValueError(error_desc)
+                else:
+                    raise ValueError(f"{self.result_arg} key not found in prediction!")
+
+        return prediction
+
     async def _agenerate(
         self,
         prompts: List[str],
@@ -458,39 +494,8 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
+        instances = self._prepare_request(prompts, **kwargs)
         response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
+            endpoint=self.endpoint_path, instances=instances
         )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append(
-                [Generation(text=prediction[self.result_arg]) for prediction in result]
-            )
-        return LLMResult(generations=generations)
+        return self._parse_response(response)
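That concludes the changes to the `VertexAIModelGarden` implementation; the remaining hunks update the Vertex AI integration tests. For reference, the request side of the refactor converts plain dict instances into protobuf `Value`s before calling the prediction client. A standalone sketch of that conversion, with the prompt, extra kwargs, and allowed args made up for illustration:

```python
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# Roughly what _prepare_request builds for one prompt, assuming
# prompt_arg="prompt" and allowed_model_args=["temperature"] (hypothetical values):
kwargs = {"temperature": 0.2, "unsupported_arg": "dropped"}
allowed_model_args = ["temperature"]

instance = {k: v for k, v in kwargs.items() if k in allowed_model_args}
instance["prompt"] = "What is the meaning of life?"

# The GAPIC predict() call expects protobuf Values rather than plain dicts.
predict_instance = json_format.ParseDict(instance, Value())
print(predict_instance)
```

Only keys listed in `allowed_model_args` survive the filter, which keeps unsupported generation kwargs from reaching the deployed endpoint.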
@@ -8,6 +8,7 @@ Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
 import os
+from typing import Optional

 import pytest
 from langchain_core.outputs import LLMResult
@@ -71,40 +72,79 @@ async def test_vertex_consistency() -> None:
     assert output.generations[0][0].text == async_output.generations[0][0].text


-def test_model_garden() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.

     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm("What is the meaning of life?")
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai_model_garden"


-def test_model_garden_generate() -> None:
-    """In order to run this test, you should provide an endpoint name.
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+def test_model_garden_generate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    """In order to run this test, you should provide endpoint names.

     Example:
-    export ENDPOINT_ID=...
+    export FALCON_ENDPOINT_ID=...
+    export LLAMA_ENDPOINT_ID=...
     export PROJECT=...
     """
-    endpoint_id = os.environ["ENDPOINT_ID"]
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = llm.generate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2


-async def test_model_garden_agenerate() -> None:
-    endpoint_id = os.environ["ENDPOINT_ID"]
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "endpoint_os_variable_name,result_arg",
+    [("FALCON_ENDPOINT_ID", "generated_text"), ("LLAMA_ENDPOINT_ID", None)],
+)
+async def test_model_garden_agenerate(
+    endpoint_os_variable_name: str, result_arg: Optional[str]
+) -> None:
+    endpoint_id = os.environ[endpoint_os_variable_name]
     project = os.environ["PROJECT"]
-    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    location = "europe-west4"
+    llm = VertexAIModelGarden(
+        endpoint_id=endpoint_id,
+        project=project,
+        result_arg=result_arg,
+        location=location,
+    )
     output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
     assert isinstance(output, LLMResult)
     assert len(output.generations) == 2
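The parametrized tests above exercise both response shapes against real deployed endpoints. Outside pytest, the async path could be driven roughly as follows; the environment variable names match the tests, while the location and `result_arg=None` assume a Llama-style endpoint that returns plain strings:

```python
import asyncio
import os

from langchain.llms import VertexAIModelGarden  # assuming the usual import path


async def main() -> None:
    llm = VertexAIModelGarden(
        endpoint_id=os.environ["LLAMA_ENDPOINT_ID"],  # as in the tests above
        project=os.environ["PROJECT"],
        location="europe-west4",
        result_arg=None,  # endpoint predictions are plain strings
    )
    output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"])
    print(output.generations[0][0].text)


asyncio.run(main())
```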