Update Vertex AI to include Gemini (#14670)
h/t to @lkuligin

- **Description:** added new models on VertexAI
- **Twitter handle:** @lkuligin

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
  parent 858f4cbce4
  commit 75b8891399
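For context, a brief usage sketch of the new Gemini support in `ChatVertexAI`, assembled from the integration tests added in this change (the model names and the GCS sample image come from those tests; configured Vertex AI credentials and a GCP project are assumed):

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.vertexai import ChatVertexAI

# Text-only chat with the new Gemini model name.
chat = ChatVertexAI(model_name="gemini-pro")
print(chat([HumanMessage(content="Hello")]).content)

# Multimodal call: a text part plus an image part referenced by a gs:// URI.
vision = ChatVertexAI(model_name="gemini-ultra-vision")
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {
            "type": "image_url",
            "image_url": {
                "url": (
                    "gs://cloud-samples-data/generative-ai/image/"
                    "320px-Felis_catus-cat_on_snow.jpg"
                )
            },
        },
    ]
)
print(vision([message]).content)
```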
@@ -1,6 +1,7 @@
 """Wrapper around Google VertexAI chat-based models."""
 from __future__ import annotations

+import base64
 import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
@@ -23,8 +24,15 @@ from langchain_core.messages import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator

-from langchain_community.llms.vertexai import _VertexAICommon, is_codey_model
-from langchain_community.utilities.vertexai import raise_vertex_import_error
+from langchain_community.llms.vertexai import (
+    _VertexAICommon,
+    is_codey_model,
+    is_gemini_model,
+)
+from langchain_community.utilities.vertexai import (
+    load_image_from_gcs,
+    raise_vertex_import_error,
+)

 if TYPE_CHECKING:
     from vertexai.language_models import (
@@ -33,6 +41,7 @@ if TYPE_CHECKING:
         CodeChatSession,
         InputOutputTextPair,
     )
+    from vertexai.preview.generative_models import Content

 logger = logging.getLogger(__name__)

@@ -77,6 +86,55 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
     return chat_history


+def _parse_chat_history_gemini(
+    history: List[BaseMessage], project: Optional[str]
+) -> List["Content"]:
+    from vertexai.preview.generative_models import Content, Image, Part
+
+    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
+        if isinstance(part, str):
+            return Part.from_text(part)
+
+        if not isinstance(part, Dict):
+            raise ValueError(
+                f"Message's content is expected to be a dict, got {type(part)}!"
+            )
+        if part["type"] == "text":
+            return Part.from_text(part["text"])
+        elif part["type"] == "image_url":
+            path = part["image_url"]["url"]
+            if path.startswith("gs://"):
+                image = load_image_from_gcs(path=path, project=project)
+            elif path.startswith("data:image/jpeg;base64,"):
+                image = Image.from_bytes(base64.b64decode(path[23:]))
+            else:
+                image = Image.load_from_file(path)
+        else:
+            raise ValueError("Only text and image_url types are supported!")
+        return Part.from_image(image)
+
+    vertex_messages = []
+    for i, message in enumerate(history):
+        if i == 0 and isinstance(message, SystemMessage):
+            raise ValueError("SystemMessages are not yet supported!")
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
+        else:
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        raw_content = message.content
+        if isinstance(raw_content, str):
+            raw_content = [raw_content]
+        parts = [_convert_to_prompt(part) for part in raw_content]
+        vertex_message = Content(role=role, parts=parts)
+        vertex_messages.append(vertex_message)
+    return vertex_messages
+
+
 def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     from vertexai.language_models import InputOutputTextPair

@@ -138,16 +196,25 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
+        is_gemini = is_gemini_model(values["model_name"])
         cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import ChatModel, CodeChatModel
+
+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
         except ImportError:
             raise_vertex_import_error()
-        if is_codey_model(values["model_name"]):
-            model_cls = CodeChatModel
+        if is_gemini:
+            values["client"] = GenerativeModel(model_name=values["model_name"])
         else:
-            model_cls = ChatModel
-        values["client"] = model_cls.from_pretrained(values["model_name"])
+            if is_codey_model(values["model_name"]):
+                model_cls = CodeChatModel
+            else:
+                model_cls = ChatModel
+            values["client"] = model_cls.from_pretrained(values["model_name"])
         return values

     def _generate(
@@ -181,18 +248,23 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             return generate_from_stream(stream_iter)

         question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=False, **kwargs)
-        examples = kwargs.get("examples") or self.examples
-        if examples:
-            params["examples"] = _parse_examples(examples)
-
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")

-        chat = self._start_chat(history, **params)
-        response = chat.send_message(question.content, **msg_params)
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = chat.send_message(message, generation_config=params)
+        else:
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples") or self.examples
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = chat.send_message(question.content, **msg_params)
         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
@@ -223,18 +295,26 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if "stream" in kwargs:
             kwargs.pop("stream")
             logger.warning("ChatVertexAI does not currently support async streaming.")
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
-        params = self._prepare_params(stop=stop, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
+
+        params = self._prepare_params(stop=stop, **kwargs)
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
-        chat = self._start_chat(history, **params)
-        response = await chat.send_message_async(question.content, **msg_params)
+
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = await chat.send_message_async(message, generation_config=params)
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = await chat.send_message_async(question.content, **msg_params)

         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
@@ -248,15 +328,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
-
-        chat = self._start_chat(history, **params)
-        responses = chat.send_message_streaming(question.content, **params)
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            responses = chat.send_message(
+                message, stream=True, generation_config=params
+            )
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
             if run_manager:
                 run_manager.on_llm_new_token(response.text)
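A short sketch of the message shapes `_parse_chat_history_gemini` accepts, following the `_convert_to_prompt` branches above (the bucket and file paths here are placeholders, not taken from this diff; the snippet only constructs messages and makes no API call):

```python
from langchain_core.messages import AIMessage, HumanMessage

# Plain string content is wrapped in a single text Part.
HumanMessage(content="Describe this picture.")

# List content mixes text parts with image parts; an image may come from a
# gs:// URI, an inline base64 JPEG data URL, or a local file path.
HumanMessage(
    content=[
        {"type": "text", "text": "What is shown here?"},
        {"type": "image_url", "image_url": {"url": "gs://my-bucket/photo.jpg"}},
        {"type": "image_url", "image_url": {"url": "/tmp/photo.jpg"}},
    ]
)

# AIMessage maps to role "model", HumanMessage to role "user"; a leading
# SystemMessage (or any other message type) raises a ValueError.
AIMessage(content="It is a cat standing in the snow.")
```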
@@ -1,17 +1,9 @@
 from __future__ import annotations

 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    ClassVar,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union

-from langchain_core.callbacks import (
+from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
@@ -37,36 +29,30 @@ if TYPE_CHECKING:
         TextGenerationResponse,
         _LanguageModel,
     )
+    from vertexai.preview.generative_models import Image

-
-def _response_to_generation(
-    response: TextGenerationResponse,
-) -> GenerationChunk:
-    """Convert a stream response to a generation chunk."""
-    try:
-        generation_info = {
-            "is_blocked": response.is_blocked,
-            "safety_attributes": response.safety_attributes,
-        }
-    except Exception:
-        generation_info = None
-    return GenerationChunk(text=response.text, generation_info=generation_info)
+# This is for backwards compatibility
+# We can remove after `langchain` stops importing it
+_response_to_generation = None
+completion_with_retry = None
+stream_completion_with_retry = None


 def is_codey_model(model_name: str) -> bool:
-    """Returns True if the model name is a Codey model.
-
-    Args:
-        model_name: The model name to check.
-
-    Returns: True if the model name is a Codey model.
-    """
+    """Returns True if the model name is a Codey model."""
     return "code" in model_name


+def is_gemini_model(model_name: str) -> bool:
+    """Returns True if the model name is a Gemini model."""
+    return model_name is not None and "gemini" in model_name
+
+
 def completion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: List[Union[str, "Image"]],
+    stream: bool = False,
+    is_gemini: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -74,33 +60,25 @@ def completion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return llm.client.generate_content(
+                prompt, stream=stream, generation_config=kwargs
+            )
+        else:
+            if stream:
+                return llm.client.predict_streaming(prompt[0], **kwargs)
+            return llm.client.predict(prompt[0], **kwargs)

-    return _completion_with_retry(*args, **kwargs)
-
-
-def stream_completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = create_retry_decorator(
-        llm, max_retries=llm.max_retries, run_manager=run_manager
-    )
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict_streaming(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(prompt, is_gemini, **kwargs)


 async def acompletion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: str,
+    is_gemini: bool = False,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -108,10 +86,16 @@ async def acompletion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return await llm.client.predict_async(*args, **kwargs)
+    async def _acompletion_with_retry(
+        prompt: str, is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return await llm.client.generate_content_async(
+                prompt, generation_config=kwargs
+            )
+        return await llm.client.predict_async(prompt, **kwargs)

-    return await _acompletion_with_retry(*args, **kwargs)
+    return await _acompletion_with_retry(prompt, is_gemini, **kwargs)


 class _VertexAIBase(BaseModel):
@@ -169,9 +153,13 @@ class _VertexAICommon(_VertexAIBase):
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)

+    @property
+    def _is_gemini_model(self) -> bool:
+        return is_gemini_model(self.model_name)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
+        """Gets the identifying parameters."""
         return {**{"model_name": self.model_name}, **self._default_params}

     @property
@@ -232,9 +220,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
-        cls._try_init_vertexai(values)
         tuned_model_name = values.get("tuned_model_name")
         model_name = values["model_name"]
+        is_gemini = is_gemini_model(values["model_name"])
+        cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import (
                 CodeGenerationModel,
@@ -247,9 +236,17 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 TextGenerationModel as PreviewTextGenerationModel,
             )

+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
+
             if is_codey_model(model_name):
                 model_cls = CodeGenerationModel
                 preview_model_cls = PreviewCodeGenerationModel
+            elif is_gemini:
+                model_cls = GenerativeModel
+                preview_model_cls = GenerativeModel
             else:
                 model_cls = TextGenerationModel
                 preview_model_cls = PreviewTextGenerationModel
@@ -260,8 +257,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     tuned_model_name
                 )
             else:
-                values["client"] = model_cls.from_pretrained(model_name)
-                values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+                if is_gemini:
+                    values["client"] = model_cls(model_name=model_name)
+                    values["client_preview"] = preview_model_cls(model_name=model_name)
+                else:
+                    values["client"] = model_cls.from_pretrained(model_name)
+                    values["client_preview"] = preview_model_cls.from_pretrained(
+                        model_name
+                    )

         except ImportError:
             raise_vertex_import_error()
@@ -288,6 +291,19 @@ class VertexAI(_VertexAICommon, BaseLLM):

         return result.total_tokens

+    def _response_to_generation(
+        self, response: TextGenerationResponse
+    ) -> GenerationChunk:
+        """Converts a stream response to a generation chunk."""
+        try:
+            generation_info = {
+                "is_blocked": response.is_blocked,
+                "safety_attributes": response.safety_attributes,
+            }
+        except Exception:
+            generation_info = None
+        return GenerationChunk(text=response.text, generation_info=generation_info)
+
     def _generate(
         self,
         prompts: List[str],
@@ -298,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
         params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             if should_stream:
                 generation = GenerationChunk(text="")
@@ -309,9 +325,16 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 generations.append([generation])
             else:
                 res = completion_with_retry(
-                    self, prompt, run_manager=run_manager, **params
+                    self,
+                    [prompt],
+                    stream=should_stream,
+                    is_gemini=self._is_gemini_model,
+                    run_manager=run_manager,
+                    **params,
                 )
-                generations.append([_response_to_generation(r) for r in res.candidates])
+                generations.append(
+                    [self._response_to_generation(r) for r in res.candidates]
+                )
         return LLMResult(generations=generations)

     async def _agenerate(
@@ -325,9 +348,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
-                self, prompt, run_manager=run_manager, **params
+                self,
+                prompt,
+                is_gemini=self._is_gemini_model,
+                run_manager=run_manager,
+                **params,
             )
-            generations.append([_response_to_generation(r) for r in res.candidates])
+            generations.append(
+                [self._response_to_generation(r) for r in res.candidates]
+            )
         return LLMResult(generations=generations)

     def _stream(
@@ -338,10 +367,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in stream_completion_with_retry(
-            self, prompt, run_manager=run_manager, **params
+        for stream_resp in completion_with_retry(
+            self,
+            [prompt],
+            stream=True,
+            is_gemini=self._is_gemini_model,
+            run_manager=run_manager,
+            **params,
         ):
-            chunk = _response_to_generation(stream_resp)
+            chunk = self._response_to_generation(stream_resp)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
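For the plain-LLM path, a minimal usage sketch assuming configured Vertex AI credentials (it mirrors the parametrized integration tests updated below; the prompts are arbitrary):

```python
from langchain_community.llms import VertexAI

# Gemini models are routed through generate_content by completion_with_retry.
llm = VertexAI(model_name="gemini-pro", temperature=0)
print(llm("Say foo:"))

# Streaming uses the same helper with stream=True.
for chunk in llm.stream("Please say foo:"):
    print(chunk, end="")
```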
@@ -11,6 +11,7 @@ from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
 if TYPE_CHECKING:
     from google.api_core.gapic_v1.client_info import ClientInfo
     from google.auth.credentials import Credentials
+    from vertexai.preview.generative_models import Image


 def create_retry_decorator(
@@ -37,7 +38,7 @@ def create_retry_decorator(
     return decorator


-def raise_vertex_import_error(minimum_expected_version: str = "1.36.0") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.38.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.

     Args:
@@ -105,3 +106,19 @@ def get_client_info(module: Optional[str] = None) -> "ClientInfo":
         client_library_version=client_library_version,
         user_agent=f"langchain/{client_library_version}",
     )
+
+
+def load_image_from_gcs(path: str, project: Optional[str] = None) -> "Image":
+    """Loads im Image from GCS."""
+    try:
+        from google.cloud import storage
+    except ImportError:
+        raise ImportError("Could not import google-cloud-storage python package.")
+    from vertexai.preview.generative_models import Image
+
+    gcs_client = storage.Client(project=project)
+    pieces = path.split("/")
+    blobs = list(gcs_client.list_blobs(pieces[2], prefix="/".join(pieces[3:])))
+    if len(blobs) > 1:
+        raise ValueError(f"Found more than one candidate for {path}!")
+    return Image.from_bytes(blobs[0].download_as_bytes())
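For illustration, a hedged sketch of calling the new GCS helper directly (bucket, object, and project names are placeholders; in the integration it is invoked internally for `gs://` image URLs and requires the `google-cloud-storage` package):

```python
from langchain_community.utilities.vertexai import load_image_from_gcs

# Downloads the blob behind a gs:// URI and wraps it in a Vertex AI Image.
image = load_image_from_gcs(
    path="gs://my-bucket/images/cat.jpg",
    project="my-gcp-project",
)
```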
@@ -25,19 +25,24 @@ from langchain_community.chat_models.vertexai import (
     _parse_examples,
 )

+model_names_to_test = [None, "codechat-bison", "chat-bison", "gemini-pro"]
+

-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_instantiation(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
     else:
         model = ChatVertexAI()
     assert model._llm_type == "vertexai"
-    assert model.model_name == model.client._model_id
+    try:
+        assert model.model_name == model.client._model_id
+    except AttributeError:
+        assert model.model_name == model.client._model_name.split("/")[-1]


 @pytest.mark.scheduled
-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_single_call(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
@@ -63,8 +68,9 @@ def test_candidates() -> None:


 @pytest.mark.scheduled
-async def test_vertexai_agenerate() -> None:
-    model = ChatVertexAI(temperature=0)
+@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
+async def test_vertexai_agenerate(model_name: str) -> None:
+    model = ChatVertexAI(temperature=0, model_name=model_name)
     message = HumanMessage(content="Hello")
     response = await model.agenerate([[message]])
     assert isinstance(response, LLMResult)
@@ -75,8 +81,9 @@ async def test_vertexai_agenerate() -> None:


 @pytest.mark.scheduled
-async def test_vertexai_stream() -> None:
-    model = ChatVertexAI(temperature=0)
+@pytest.mark.parametrize("model_name", ["chat-bison@001", "gemini-pro"])
+def test_vertexai_stream(model_name: str) -> None:
+    model = ChatVertexAI(temperature=0, model_name=model_name)
     message = HumanMessage(content="Hello")

     sync_response = model.stream([message])
@@ -101,6 +108,53 @@ def test_vertexai_single_call_with_context() -> None:
     assert isinstance(response.content, str)


+def test_multimodal() -> None:
+    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    gcs_url = (
+        "gs://cloud-samples-data/generative-ai/image/"
+        "320px-Felis_catus-cat_on_snow.jpg"
+    )
+    image_message = {
+        "type": "image_url",
+        "image_url": {"url": gcs_url},
+    }
+    text_message = {
+        "type": "text",
+        "text": "What is shown in this image?",
+    }
+    message = HumanMessage(content=[text_message, image_message])
+    output = llm([message])
+    assert isinstance(output.content, str)
+
+
+def test_multimodal_history() -> None:
+    llm = ChatVertexAI(model_name="gemini-ultra-vision")
+    gcs_url = (
+        "gs://cloud-samples-data/generative-ai/image/"
+        "320px-Felis_catus-cat_on_snow.jpg"
+    )
+    image_message = {
+        "type": "image_url",
+        "image_url": {"url": gcs_url},
+    }
+    text_message = {
+        "type": "text",
+        "text": "What is shown in this image?",
+    }
+    message1 = HumanMessage(content=[text_message, image_message])
+    message2 = AIMessage(
+        content=(
+            "This is a picture of a cat in the snow. The cat is a tabby cat, which is "
+            "a type of cat with a striped coat. The cat is standing in the snow, and "
+            "its fur is covered in snow."
+        )
+    )
+    message3 = HumanMessage(content="What time of day is it?")
+    response = llm([message1, message2, message3])
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+
+
 @pytest.mark.scheduled
 def test_vertexai_single_call_with_examples() -> None:
     model = ChatVertexAI()
@@ -117,7 +171,7 @@ def test_vertexai_single_call_with_examples() -> None:


 @pytest.mark.scheduled
-@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
+@pytest.mark.parametrize("model_name", model_names_to_test)
 def test_vertexai_single_call_with_history(model_name: str) -> None:
     if model_name:
         model = ChatVertexAI(model_name=model_name)
@@ -13,15 +13,33 @@ from langchain_core.outputs import LLMResult

 from langchain_community.llms import VertexAI, VertexAIModelGarden

+model_names_to_test = ["text-bison@001", "gemini-pro"]
+model_names_to_test_with_default = [None] + model_names_to_test
+

-def test_vertex_initialization() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_initialization(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name) if model_name else VertexAI()
     assert llm._llm_type == "vertexai"
-    assert llm.model_name == llm.client._model_id
+    try:
+        assert llm.model_name == llm.client._model_id
+    except AttributeError:
+        assert llm.model_name == llm.client._model_name.split("/")[-1]


-def test_vertex_call() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_call(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
     output = llm("Say foo:")
     assert isinstance(output, str)

@@ -52,8 +70,16 @@ async def test_vertex_agenerate() -> None:


 @pytest.mark.scheduled
-def test_vertex_stream() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_stream(model_name: str) -> None:
+    llm = (
+        VertexAI(temperature=0, model_name=model_name)
+        if model_name
+        else VertexAI(temperature=0)
+    )
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)

@@ -145,7 +171,11 @@ async def test_model_garden_agenerate(
     assert len(output.generations) == 2


-def test_vertex_call_count_tokens() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test,
+)
+def test_vertex_call_count_tokens(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name)
     output = llm.get_num_tokens("How are you?")
     assert output == 4