Update Vertex AI to include Gemini (#14670)

h/t to @lkuligin

- **Description:** added support for the new Gemini models on Vertex AI
- **Twitter handle:** @lkuligin

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
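For orientation: after this change, Gemini models route through the same `ChatVertexAI` wrapper as the PaLM models. A minimal usage sketch, assuming a `langchain-community` build containing this change, the `google-cloud-aiplatform` SDK with `vertexai.preview.generative_models` available, and configured GCP credentials (`gemini-pro` is an illustrative model name):

```python
from langchain_community.chat_models import ChatVertexAI
from langchain_core.messages import HumanMessage

# Gemini model names are detected and routed to the new
# GenerativeModel client; PaLM names keep the old code path.
chat = ChatVertexAI(model_name="gemini-pro")
response = chat.invoke([HumanMessage(content="Hello, Gemini!")])
print(response.content)
```

The diff threads Gemini support through the imports, the client initialization, and the sync, async, and streaming generation paths.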
```diff
@@ -1,6 +1,7 @@
 """Wrapper around Google VertexAI chat-based models."""
 from __future__ import annotations
 
+import base64
 import logging
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast
```
```diff
@@ -23,8 +24,15 @@ from langchain_core.messages import (
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import root_validator
 
-from langchain_community.llms.vertexai import _VertexAICommon, is_codey_model
-from langchain_community.utilities.vertexai import raise_vertex_import_error
+from langchain_community.llms.vertexai import (
+    _VertexAICommon,
+    is_codey_model,
+    is_gemini_model,
+)
+from langchain_community.utilities.vertexai import (
+    load_image_from_gcs,
+    raise_vertex_import_error,
+)
 
 if TYPE_CHECKING:
     from vertexai.language_models import (
```
```diff
@@ -33,6 +41,7 @@ if TYPE_CHECKING:
         CodeChatSession,
         InputOutputTextPair,
     )
+    from vertexai.preview.generative_models import Content
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -77,6 +86,55 @@ def _parse_chat_history(history: List[BaseMessage]) -> _ChatHistory:
     return chat_history
 
 
+def _parse_chat_history_gemini(
+    history: List[BaseMessage], project: Optional[str]
+) -> List["Content"]:
+    from vertexai.preview.generative_models import Content, Image, Part
+
+    def _convert_to_prompt(part: Union[str, Dict]) -> Part:
+        if isinstance(part, str):
+            return Part.from_text(part)
+
+        if not isinstance(part, Dict):
+            raise ValueError(
+                f"Message's content is expected to be a dict, got {type(part)}!"
+            )
+        if part["type"] == "text":
+            return Part.from_text(part["text"])
+        elif part["type"] == "image_url":
+            path = part["image_url"]["url"]
+            if path.startswith("gs://"):
+                image = load_image_from_gcs(path=path, project=project)
+            elif path.startswith("data:image/jpeg;base64,"):
+                image = Image.from_bytes(base64.b64decode(path[23:]))
+            else:
+                image = Image.load_from_file(path)
+        else:
+            raise ValueError("Only text and image_url types are supported!")
+        return Part.from_image(image)
+
+    vertex_messages = []
+    for i, message in enumerate(history):
+        if i == 0 and isinstance(message, SystemMessage):
+            raise ValueError("SystemMessages are not yet supported!")
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
+        else:
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        raw_content = message.content
+        if isinstance(raw_content, str):
+            raw_content = [raw_content]
+        parts = [_convert_to_prompt(part) for part in raw_content]
+        vertex_message = Content(role=role, parts=parts)
+        vertex_messages.append(vertex_message)
+    return vertex_messages
+
+
 def _parse_examples(examples: List[BaseMessage]) -> List["InputOutputTextPair"]:
     from vertexai.language_models import InputOutputTextPair
 
```
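As a sketch of the content shapes the new parser accepts: each element of a message's content list may be a plain string, a `{"type": "text", ...}` dict, or a `{"type": "image_url", ...}` dict whose URL is a `gs://` URI, an inline base64 JPEG data URL, or a local file path. The bucket path below is hypothetical:

```python
from langchain_core.messages import HumanMessage

# One text part plus one image part; _convert_to_prompt turns each
# into a vertexai Part. gs:// URIs go through load_image_from_gcs,
# "data:image/jpeg;base64,..." is decoded inline, and anything else
# is treated as a local file path.
message = HumanMessage(
    content=[
        {"type": "text", "text": "What is shown in this image?"},
        {"type": "image_url", "image_url": {"url": "gs://my-bucket/photo.jpg"}},
    ]
)
```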
```diff
@@ -138,16 +196,25 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
+        is_gemini = is_gemini_model(values["model_name"])
         cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import ChatModel, CodeChatModel
+
+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
         except ImportError:
             raise_vertex_import_error()
-        if is_codey_model(values["model_name"]):
-            model_cls = CodeChatModel
+        if is_gemini:
+            values["client"] = GenerativeModel(model_name=values["model_name"])
         else:
-            model_cls = ChatModel
-        values["client"] = model_cls.from_pretrained(values["model_name"])
+            if is_codey_model(values["model_name"]):
+                model_cls = CodeChatModel
+            else:
+                model_cls = ChatModel
+            values["client"] = model_cls.from_pretrained(values["model_name"])
         return values
 
     def _generate(
```
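The client selection the validator now performs, summarized as a sketch (not part of the commit; model names are illustrative, and running this requires working GCP credentials):

```python
from langchain_community.chat_models import ChatVertexAI

# gemini-*       -> vertexai.preview.generative_models.GenerativeModel
# codechat-*     -> vertexai.language_models.CodeChatModel.from_pretrained
# anything else  -> vertexai.language_models.ChatModel.from_pretrained
for name in ["gemini-pro", "codechat-bison", "chat-bison"]:
    chat = ChatVertexAI(model_name=name)
    print(name, "->", type(chat.client).__name__)
```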
```diff
@@ -181,18 +248,23 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             return generate_from_stream(stream_iter)
 
         question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=False, **kwargs)
-        examples = kwargs.get("examples") or self.examples
-        if examples:
-            params["examples"] = _parse_examples(examples)
-
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
 
-        chat = self._start_chat(history, **params)
-        response = chat.send_message(question.content, **msg_params)
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = chat.send_message(message, generation_config=params)
+        else:
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples") or self.examples
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = chat.send_message(question.content, **msg_params)
         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
```
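The Gemini branch converts the whole message list, pops the final `Content` to send, and forwards the sampling parameters as `generation_config` (the wrapper itself passes a plain dict). Roughly equivalent direct use of the preview SDK, as a sketch (assumes `vertexai.init()` has been called; values are illustrative):

```python
from vertexai.preview.generative_models import GenerativeModel

model = GenerativeModel("gemini-pro")
# History carries every prior message; the last message is what is
# actually sent, with params forwarded as generation_config.
session = model.start_chat(history=[])
response = session.send_message(
    "Hello!", generation_config={"temperature": 0.2}
)
print(response.text)
```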
```diff
@@ -223,18 +295,26 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if "stream" in kwargs:
             kwargs.pop("stream")
             logger.warning("ChatVertexAI does not currently support async streaming.")
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
-        params = self._prepare_params(stop=stop, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
+
+        params = self._prepare_params(stop=stop, **kwargs)
         msg_params = {}
         if "candidate_count" in params:
             msg_params["candidate_count"] = params.pop("candidate_count")
-        chat = self._start_chat(history, **params)
-        response = await chat.send_message_async(question.content, **msg_params)
+
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            response = await chat.send_message_async(message, generation_config=params)
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            response = await chat.send_message_async(question.content, **msg_params)
 
         generations = [
             ChatGeneration(message=AIMessage(content=r.text))
             for r in response.candidates
```
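From the caller's side the async path might be exercised as follows, a sketch with an illustrative model name:

```python
import asyncio

from langchain_community.chat_models import ChatVertexAI
from langchain_core.messages import HumanMessage

async def main() -> None:
    chat = ChatVertexAI(model_name="gemini-pro")
    # agenerate awaits send_message_async under the hood.
    result = await chat.agenerate([[HumanMessage(content="Hello!")]])
    print(result.generations[0][0].text)

asyncio.run(main())
```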
```diff
@@ -248,15 +328,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        question = _get_question(messages)
-        history = _parse_chat_history(messages[:-1])
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        examples = kwargs.get("examples", None)
-        if examples:
-            params["examples"] = _parse_examples(examples)
-
-        chat = self._start_chat(history, **params)
-        responses = chat.send_message_streaming(question.content, **params)
+        if self._is_gemini_model:
+            history_gemini = _parse_chat_history_gemini(messages, project=self.project)
+            message = history_gemini.pop()
+            chat = self.client.start_chat(history=history_gemini)
+            responses = chat.send_message(
+                message, stream=True, generation_config=params
+            )
+        else:
+            question = _get_question(messages)
+            history = _parse_chat_history(messages[:-1])
+            examples = kwargs.get("examples", None)
+            if examples:
+                params["examples"] = _parse_examples(examples)
+            chat = self._start_chat(history, **params)
+            responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
            if run_manager:
                 run_manager.on_llm_new_token(response.text)
```
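Streaming on the Gemini branch passes `stream=True` to `send_message` and surfaces each chunk through `on_llm_new_token`. A usage sketch (model name illustrative):

```python
from langchain_community.chat_models import ChatVertexAI
from langchain_core.messages import HumanMessage

chat = ChatVertexAI(model_name="gemini-pro")
# Each streamed candidate chunk arrives as an AIMessageChunk.
for chunk in chat.stream([HumanMessage(content="Tell me a short story.")]):
    print(chunk.content, end="", flush=True)
```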