Update Vertex AI to include Gemini (#14670)

h/t to @lkuligin 
- **Description:** added new Gemini models to the Vertex AI integrations
- **Twitter handle:** @lkuligin
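
For context, a minimal usage sketch of the new Gemini support (assumes Vertex AI credentials are already configured; `gemini-pro` is the model name exercised in the new integration tests):

```python
from langchain_community.chat_models import ChatVertexAI
from langchain_community.llms import VertexAI
from langchain_core.messages import HumanMessage

# Text completion through the Gemini model family.
llm = VertexAI(model_name="gemini-pro", temperature=0)
print(llm("Say foo:"))

# Chat usage; the message history is converted to Gemini `Content` objects internally.
chat = ChatVertexAI(model_name="gemini-pro", temperature=0)
print(chat([HumanMessage(content="Hello")]).content)
```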

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
William FH 2023-12-13 10:45:02 -08:00 committed by GitHub
parent 858f4cbce4
commit 75b8891399
6 changed files with 595 additions and 197 deletions

File diff suppressed because one or more lines are too long

File: langchain_community/chat_models/vertexai.py

@@ -1,6 +1,7 @@
 """Wrapper around Google VertexAI chat-based models."""
 from __future__ import annotations

+import base64
 import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
@@ -23,8 +24,15 @@ from langchain_core.messages import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator

-from langchain_community.llms.vertexai import _VertexAICommon, is_codey_model
-from langchain_community.utilities.vertexai import raise_vertex_import_error
+from langchain_community.llms.vertexai import (
+    _VertexAICommon,
+    is_codey_model,
+    is_gemini_model,
+)
+from langchain_community.utilities.vertexai import (
+    load_image_from_gcs,
+    raise_vertex_import_error,
+)

 if TYPE_CHECKING:
     from vertexai.language_models import (
@@ -33,6 +41,7 @@ if TYPE_CHECKING:
         CodeChatSession,
         InputOutputTextPair,
     )
+    from vertexai.preview.generative_models import Content

 logger = logging.getLogger(__name__)
@@ -77,6 +86,55 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
     return chat_history


+def _parse_chat_history_gemini(
+    history: List[BaseMessage], project: Optional[str]
+) -> List["Content"]:
+    from vertexai.preview.generative_models import Content, Image, Part
+
+    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
+        if isinstance(part, str):
+            return Part.from_text(part)
+
+        if not isinstance(part, Dict):
+            raise ValueError(
+                f"Message's content is expected to be a dict, got {type(part)}!"
+            )
+        if part["type"] == "text":
+            return Part.from_text(part["text"])
+        elif part["type"] == "image_url":
+            path = part["image_url"]["url"]
+            if path.startswith("gs://"):
+                image = load_image_from_gcs(path=path, project=project)
+            elif path.startswith("data:image/jpeg;base64,"):
+                image = Image.from_bytes(base64.b64decode(path[23:]))
+            else:
+                image = Image.load_from_file(path)
+        else:
+            raise ValueError("Only text and image_url types are supported!")
+        return Part.from_image(image)
+
+    vertex_messages = []
+    for i, message in enumerate(history):
+        if i == 0 and isinstance(message, SystemMessage):
+            raise ValueError("SystemMessages are not yet supported!")
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
+        else:
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+        raw_content = message.content
+        if isinstance(raw_content, str):
+            raw_content = [raw_content]
+        parts = [_convert_to_prompt(part) for part in raw_content]
+        vertex_message = Content(role=role, parts=parts)
+        vertex_messages.append(vertex_message)
+    return vertex_messages
+
+
 def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     from vertexai.language_models import InputOutputTextPair
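
A quick sketch of what the new parser produces for a short history (roles follow the HumanMessage -> "user", AIMessage -> "model" mapping above; requires the `vertexai` SDK at call time):

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.chat_models.vertexai import _parse_chat_history_gemini

history = [
    HumanMessage(content="Hello"),
    AIMessage(content="Hi! How can I help?"),
    HumanMessage(content="What is 2 + 2?"),
]
# Returns vertexai Content objects, roughly:
# [Content(role="user", ...), Content(role="model", ...), Content(role="user", ...)]
contents = _parse_chat_history_gemini(history, project=None)
```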
@@ -138,16 +196,25 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
+        is_gemini = is_gemini_model(values["model_name"])
         cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import ChatModel, CodeChatModel

+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
         except ImportError:
             raise_vertex_import_error()
-        if is_codey_model(values["model_name"]):
-            model_cls = CodeChatModel
+        if is_gemini:
+            values["client"] = GenerativeModel(model_name=values["model_name"])
         else:
-            model_cls = ChatModel
-        values["client"] = model_cls.from_pretrained(values["model_name"])
+            if is_codey_model(values["model_name"]):
+                model_cls = CodeChatModel
+            else:
+                model_cls = ChatModel
+            values["client"] = model_cls.from_pretrained(values["model_name"])
         return values

     def _generate(
@@ -181,18 +248,23 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             return generate_from_stream(stream_iter)

         question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=False, **kwargs)
-        examples = kwargs.get("examples") or self.examples
-        if examples:
-            params["examples"] = _parse_examples(examples)
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
-        chat = self._start_chat(history, **params)
-        response = chat.send_message(question.content, **msg_params)
+
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = chat.send_message(message, generation_config=params)
+        else:
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples") or self.examples
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = chat.send_message(question.content, **msg_params)
         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
@@ -223,18 +295,26 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if "stream" in kwargs:
             kwargs.pop("stream")
             logger.warning("ChatVertexAI does not currently support async streaming.")
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
-        params = self._prepare_params(stop=stop, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
+
+        params = self._prepare_params(stop=stop, **kwargs)
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
-        chat = self._start_chat(history, **params)
-        response = await chat.send_message_async(question.content, **msg_params)
+
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = await chat.send_message_async(message, generation_config=params)
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = await chat.send_message_async(question.content, **msg_params)
+
         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
@@ -248,15 +328,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
-
-        chat = self._start_chat(history, **params)
-        responses = chat.send_message_streaming(question.content, **params)
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            responses = chat.send_message(
+                message, stream=True, generation_config=params
+            )
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
             if run_manager:
                 run_manager.on_llm_new_token(response.text)
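
Usage sketch for the new multimodal chat path (the model name follows the new integration tests below; the gs:// URL is illustrative, not taken from the diff):

```python
from langchain_community.chat_models import ChatVertexAI
from langchain_core.messages import HumanMessage

# Each content part is either {"type": "text", ...} or {"type": "image_url", ...};
# image URLs may be gs:// paths, base64 data URLs, or local file paths, matching
# the branches in _convert_to_prompt above.
llm = ChatVertexAI(model_name="gemini-ultra-vision")
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image_url", "image_url": {"url": "gs://my-bucket/cat.jpg"}},
    ]
)
print(llm([message]).content)
```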

File: langchain_community/llms/vertexai.py

@@ -1,17 +1,9 @@
 from __future__ import annotations

 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    ClassVar,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union

-from langchain_core.callbacks import (
+from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
@@ -37,36 +29,30 @@ if TYPE_CHECKING:
         TextGenerationResponse,
         _LanguageModel,
     )
+    from vertexai.preview.generative_models import Image

-
-def _response_to_generation(
-    response: TextGenerationResponse,
-) -> GenerationChunk:
-    """Convert a stream response to a generation chunk."""
-    try:
-        generation_info = {
-            "is_blocked": response.is_blocked,
-            "safety_attributes": response.safety_attributes,
-        }
-    except Exception:
-        generation_info = None
-    return GenerationChunk(text=response.text, generation_info=generation_info)
+# This is for backwards compatibility
+# We can remove after `langchain` stops importing it
+_response_to_generation = None
+completion_with_retry = None
+stream_completion_with_retry = None


 def is_codey_model(model_name: str) -> bool:
-    """Returns True if the model name is a Codey model.
-
-    Args:
-        model_name: The model name to check.
-
-    Returns: True if the model name is a Codey model.
-    """
+    """Returns True if the model name is a Codey model."""
     return "code" in model_name


+def is_gemini_model(model_name: str) -> bool:
+    """Returns True if the model name is a Gemini model."""
+    return model_name is not None and "gemini" in model_name
+
+
 def completion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: List[Union[str, "Image"]],
+    stream: bool = False,
+    is_gemini: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -74,33 +60,25 @@ def completion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return llm.client.generate_content(
+                prompt, stream=stream, generation_config=kwargs
+            )
+        else:
+            if stream:
+                return llm.client.predict_streaming(prompt[0], **kwargs)
+            return llm.client.predict(prompt[0], **kwargs)

-    return _completion_with_retry(*args, **kwargs)
-
-
-def stream_completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = create_retry_decorator(
-        llm, max_retries=llm.max_retries, run_manager=run_manager
-    )
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict_streaming(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(prompt, is_gemini, **kwargs)


 async def acompletion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: str,
+    is_gemini: bool = False,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -108,10 +86,16 @@ async def acompletion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return await llm.client.predict_async(*args, **kwargs)
+    async def _acompletion_with_retry(
+        prompt: str, is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return await llm.client.generate_content_async(
+                prompt, generation_config=kwargs
+            )
+        return await llm.client.predict_async(prompt, **kwargs)

-    return await _acompletion_with_retry(*args, **kwargs)
+    return await _acompletion_with_retry(prompt, is_gemini, **kwargs)


 class _VertexAIBase(BaseModel):
@@ -169,9 +153,13 @@ class _VertexAICommon(_VertexAIBase):
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)

+    @property
+    def _is_gemini_model(self) -> bool:
+        return is_gemini_model(self.model_name)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
+        """Gets the identifying parameters."""
         return {**{"model_name": self.model_name}, **self._default_params}

     @property
@@ -232,9 +220,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
-        cls._try_init_vertexai(values)
         tuned_model_name = values.get("tuned_model_name")
         model_name = values["model_name"]
+        is_gemini = is_gemini_model(values["model_name"])
+        cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import (
                 CodeGenerationModel,
@@ -247,9 +236,17 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 TextGenerationModel as PreviewTextGenerationModel,
             )

+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
+
             if is_codey_model(model_name):
                 model_cls = CodeGenerationModel
                 preview_model_cls = PreviewCodeGenerationModel
+            elif is_gemini:
+                model_cls = GenerativeModel
+                preview_model_cls = GenerativeModel
             else:
                 model_cls = TextGenerationModel
                 preview_model_cls = PreviewTextGenerationModel
@@ -260,8 +257,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     tuned_model_name
                 )
             else:
-                values["client"] = model_cls.from_pretrained(model_name)
-                values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+                if is_gemini:
+                    values["client"] = model_cls(model_name=model_name)
+                    values["client_preview"] = preview_model_cls(model_name=model_name)
+                else:
+                    values["client"] = model_cls.from_pretrained(model_name)
+                    values["client_preview"] = preview_model_cls.from_pretrained(
+                        model_name
+                    )

         except ImportError:
             raise_vertex_import_error()
@@ -288,6 +291,19 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens

+    def _response_to_generation(
+        self, response: TextGenerationResponse
+    ) -> GenerationChunk:
+        """Converts a stream response to a generation chunk."""
+        try:
+            generation_info = {
+                "is_blocked": response.is_blocked,
+                "safety_attributes": response.safety_attributes,
+            }
+        except Exception:
+            generation_info = None
+        return GenerationChunk(text=response.text, generation_info=generation_info)
+
     def _generate(
         self,
         prompts: List[str],
@@ -298,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
         params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             if should_stream:
                 generation = GenerationChunk(text="")
@@ -309,9 +325,16 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 generations.append([generation])
             else:
                 res = completion_with_retry(
-                    self, prompt, run_manager=run_manager, **params
+                    self,
+                    [prompt],
+                    stream=should_stream,
+                    is_gemini=self._is_gemini_model,
+                    run_manager=run_manager,
+                    **params,
+                )
+                generations.append(
+                    [self._response_to_generation(r) for r in res.candidates]
                 )
-                generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)

     async def _agenerate(
@@ -325,9 +348,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
-                self, prompt, run_manager=run_manager, **params
+                self,
+                prompt,
+                is_gemini=self._is_gemini_model,
+                run_manager=run_manager,
+                **params,
+            )
+            generations.append(
+                [self._response_to_generation(r) for r in res.candidates]
             )
-            generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)

     def _stream(
@@ -338,10 +367,15 @@
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in stream_completion_with_retry(
-            self, prompt, run_manager=run_manager, **params
+        for stream_resp in completion_with_retry(
+            self,
+            [prompt],
+            stream=True,
+            is_gemini=self._is_gemini_model,
+            run_manager=run_manager,
+            **params,
         ):
-            chunk = _response_to_generation(stream_resp)
+            chunk = self._response_to_generation(stream_resp)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
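
Usage sketch for the consolidated streaming path (Gemini and PaLM models now both go through `completion_with_retry(..., stream=True)`; model name is illustrative):

```python
from langchain_community.llms import VertexAI

llm = VertexAI(model_name="gemini-pro", temperature=0)
# Streams plain text chunks for both Gemini (generate_content) and
# PaLM (predict_streaming) backends.
for chunk in llm.stream("Please say foo:"):
    print(chunk, end="", flush=True)
```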

File: langchain_community/utilities/vertexai.py

@@ -11,6 +11,7 @@ from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials
+    from vertexai.preview.generative_models import Image


 def create_retry_decorator(
@@ -37,7 +38,7 @@
     return decorator


-def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.

     Args:
@@ -105,3 +106,19 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
         client_library_version=client_library_version,
         user_agent=f"langchain/{client_library_version}",
     )
+
+
+def load_image_from_gcs(path: str, project: Optional[str] = None) -> "Image":
+    """Loads im Image from GCS."""
+    try:
+        from google.cloud import storage
+    except ImportError:
+        raise ImportError("Could not import google-cloud-storage python package.")
+    from vertexai.preview.generative_models import Image
+
+    gcs_client = storage.Client(project=project)
+    pieces = path.split("/")
+    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
+    if len(blobs) > 1:
+        raise ValueError(f"Found more than one candidate for {path}!")
+    return Image.from_bytes(blobs[0].download_as_bytes())
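
A small sketch of how the new helper resolves a `gs://<bucket>/<object>` URI (bucket, object, and project below are illustrative):

```python
# "gs://my-bucket/images/cat.jpg".split("/") ->
# ["gs:", "", "my-bucket", "images", "cat.jpg"], so pieces[2] is the bucket
# and "/".join(pieces[3:]) == "images/cat.jpg" is the blob prefix.
image = load_image_from_gcs("gs://my-bucket/images/cat.jpg", project="my-project")
```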

File: integration tests for langchain_community.chat_models.vertexai

@@ -25,19 +25,24 @@ from langchain_community.chat_models.vertexai import (
     _parse_examples,
 )

+model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]


-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_instantiation(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
     else:
         model = ChatVertexAI()
     assert model._llm_type == "vertexai"
-    assert model.model_name == model.client._model_id
+    try:
+        assert model.model_name == model.client._model_id
+    except AttributeError:
+        assert model.model_name == model.client._model_name.split("/")[-1]


 @pytest.mark.scheduled
-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_single_call(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
@@ -63,8 +68,9 @@ def test_candidates() -> None:
 @pytest.mark.scheduled
-async def test_vertexai_agenerate() -> None:
-    model = ChatVertexAI(temperature=0)
+@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
+async def test_vertexai_agenerate(model_name: str) -> None:
+    model = ChatVertexAI(temperature=0, model_name=model_name)
     message = HumanMessage(content="Hello")
     response = await model.agenerate([[message]])
     assert isinstance(response, LLMResult)
@@ -75,8 +81,9 @@
 @pytest.mark.scheduled
-async def test_vertexai_stream() -> None:
-    model = ChatVertexAI(temperature=0)
+@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
+def test_vertexai_stream(model_name: str) -> None:
+    model = ChatVertexAI(temperature=0, model_name=model_name)
     message = HumanMessage(content="Hello")

     sync_response = model.stream([message])
@@ -101,6 +108,53 @@
     assert isinstance(response.content, str)


+def test_multimodal() -> None:
+    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    gcs_url = (
+        "gs://cloud-samples-data/generative-ai/image/"
+        "320px-Felis_catus-cat_on_snow.jpg"
+    )
+    image_message = {
+        "type": "image_url",
+        "image_url": {"url": gcs_url},
+    }
+    text_message = {
+        "type": "text",
+        "text": "What is shown in this image?",
+    }
+    message = HumanMessage(content=[text_message, image_message])
+    output = llm([message])
+    assert isinstance(output.content, str)
+
+
+def test_multimodal_history() -> None:
+    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    gcs_url = (
+        "gs://cloud-samples-data/generative-ai/image/"
+        "320px-Felis_catus-cat_on_snow.jpg"
+    )
+    image_message = {
+        "type": "image_url",
+        "image_url": {"url": gcs_url},
+    }
+    text_message = {
+        "type": "text",
+        "text": "What is shown in this image?",
+    }
+    message1 = HumanMessage(content=[text_message, image_message])
+    message2 = AIMessage(
+        content=(
+            "This is a picture of a cat in the snow. The cat is a tabby cat, which is "
+            "a type of cat with a striped coat. The cat is standing in the snow, and "
+            "its fur is covered in snow."
+        )
+    )
+    message3 = HumanMessage(content="What time of day is it?")
+    response = llm([message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
 @pytest.mark.scheduled
 def test_vertexai_single_call_with_examples() -> None:
     model = ChatVertexAI()
@@ -117,7 +171,7 @@ def test_vertexai_single_call_with_examples() -> None:
 @pytest.mark.scheduled
-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_single_call_with_history(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)

File: integration tests for the VertexAI LLM (langchain_community.llms)

@@ -13,15 +13,33 @@ from langchain_core.outputs import LLMResult

 from langchain_community.llms import VertexAI, VertexAIModelGarden

+model_names_to_test = ["text-bison@001", "gemini-pro"]
+model_names_to_test_with_default = [None] + model_names_to_test


-def test_vertex_initialization() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_initialization(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name) if model_name else VertexAI()
     assert llm._llm_type == "vertexai"
-    assert llm.model_name == llm.client._model_id
+    try:
+        assert llm.model_name == llm.client._model_id
+    except AttributeError:
+        assert llm.model_name == llm.client._model_name.split("/")[-1]


-def test_vertex_call() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_call(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
     output = llm("Say foo:")
     assert isinstance(output, str)
@@ -52,8 +70,16 @@
 @pytest.mark.scheduled
-def test_vertex_stream() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_stream(model_name: str) -> None:
+    llm = (
+        VertexAI(temperature=0, model_name=model_name)
+        if model_name
+        else VertexAI(temperature=0)
+    )
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
@@ -145,7 +171,11 @@ async def test_model_garden_agenerate(
     assert len(output.generations) == 2


-def test_vertex_call_count_tokens() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test,
+)
+def test_vertex_call_count_tokens(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name)
     output = llm.get_num_tokens("How are you?")
     assert output == 4