Update Vertex AI to include Gemini (#14670)

h/t to @lkuligin
- **Description:** adds support for the new Gemini models to the VertexAI LLM integration (see the usage sketch below).
- **Twitter handle:** @lkuligin
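
A minimal usage sketch, assuming the `langchain_community` entry point, an already-configured GCP project with the Vertex AI API enabled, and `gemini-pro` as an example model name (adjust for your environment):

```python
from langchain_community.llms import VertexAI

# Assumes application-default credentials and a project with Vertex AI enabled;
# "gemini-pro" is used here purely as an illustrative model name.
llm = VertexAI(model_name="gemini-pro", temperature=0.2, max_output_tokens=256)

print(llm.invoke("Write a haiku about the Alps."))

# Streaming goes through the same Gemini-aware code path.
for chunk in llm.stream("Name three uses of the Vertex AI SDK."):
    print(chunk, end="", flush=True)
```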

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
William FH committed on 2023-12-13 10:45:02 -08:00 (via GitHub)
parent 858f4cbce4
commit 75b8891399
6 changed files with 595 additions and 197 deletions
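
For reviewers skimming the diff below: the heart of the change is that `completion_with_retry` now receives the prompt plus `stream`/`is_gemini` flags and routes Gemini models to the `GenerativeModel.generate_content` API, while PaLM/Codey models keep the `predict`/`predict_streaming` path. A condensed sketch of that routing, using `_route_completion` as an illustrative name (the real code wraps this logic in a tenacity retry decorator inside `completion_with_retry`):

```python
from typing import Any, List


def _route_completion(
    llm: Any, prompt: List[Any], *, stream: bool = False, is_gemini: bool = False, **kwargs: Any
) -> Any:
    # Gemini models use GenerativeModel.generate_content and take the full prompt list.
    if is_gemini:
        return llm.client.generate_content(prompt, stream=stream, generation_config=kwargs)
    # Legacy text/code models keep the single-string predict / predict_streaming calls.
    if stream:
        return llm.client.predict_streaming(prompt[0], **kwargs)
    return llm.client.predict(prompt[0], **kwargs)
```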


@@ -1,17 +1,9 @@
 from __future__ import annotations
 
 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    ClassVar,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union
 
-from langchain_core.callbacks import (
+from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
@@ -37,36 +29,30 @@ if TYPE_CHECKING:
         TextGenerationResponse,
         _LanguageModel,
     )
+    from vertexai.preview.generative_models import Image
 
-
-def _response_to_generation(
-    response: TextGenerationResponse,
-) -> GenerationChunk:
-    """Convert a stream response to a generation chunk."""
-    try:
-        generation_info = {
-            "is_blocked": response.is_blocked,
-            "safety_attributes": response.safety_attributes,
-        }
-    except Exception:
-        generation_info = None
-    return GenerationChunk(text=response.text, generation_info=generation_info)
+# This is for backwards compatibility
+# We can remove after `langchain` stops importing it
+_response_to_generation = None
+completion_with_retry = None
+stream_completion_with_retry = None
 
 
 def is_codey_model(model_name: str) -> bool:
-    """Returns True if the model name is a Codey model.
-
-    Args:
-        model_name: The model name to check.
-
-    Returns: True if the model name is a Codey model.
-    """
+    """Returns True if the model name is a Codey model."""
     return "code" in model_name
 
 
+def is_gemini_model(model_name: str) -> bool:
+    """Returns True if the model name is a Gemini model."""
+    return model_name is not None and "gemini" in model_name
+
+
 def completion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: List[Union[str, "Image"]],
+    stream: bool = False,
+    is_gemini: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -74,33 +60,25 @@ def completion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return llm.client.generate_content(
+                prompt, stream=stream, generation_config=kwargs
+            )
+        else:
+            if stream:
+                return llm.client.predict_streaming(prompt[0], **kwargs)
+            return llm.client.predict(prompt[0], **kwargs)
 
-    return _completion_with_retry(*args, **kwargs)
-
-
-def stream_completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = create_retry_decorator(
-        llm, max_retries=llm.max_retries, run_manager=run_manager
-    )
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict_streaming(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(prompt, is_gemini, **kwargs)
 
 
 async def acompletion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: str,
+    is_gemini: bool = False,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -108,10 +86,16 @@ async def acompletion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)
 
     @retry_decorator
-    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return await llm.client.predict_async(*args, **kwargs)
+    async def _acompletion_with_retry(
+        prompt: str, is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return await llm.client.generate_content_async(
+                prompt, generation_config=kwargs
+            )
+        return await llm.client.predict_async(prompt, **kwargs)
 
-    return await _acompletion_with_retry(*args, **kwargs)
+    return await _acompletion_with_retry(prompt, is_gemini, **kwargs)
 
 
 class _VertexAIBase(BaseModel):
@@ -169,9 +153,13 @@ class _VertexAICommon(_VertexAIBase):
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)
 
+    @property
+    def _is_gemini_model(self) -> bool:
+        return is_gemini_model(self.model_name)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
+        """Gets the identifying parameters."""
         return {**{"model_name": self.model_name}, **self._default_params}
 
     @property
@@ -232,9 +220,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
-        cls._try_init_vertexai(values)
         tuned_model_name = values.get("tuned_model_name")
         model_name = values["model_name"]
+        is_gemini = is_gemini_model(values["model_name"])
+        cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import (
                 CodeGenerationModel,
@@ -247,9 +236,17 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 TextGenerationModel as PreviewTextGenerationModel,
             )
 
+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
+
             if is_codey_model(model_name):
                 model_cls = CodeGenerationModel
                 preview_model_cls = PreviewCodeGenerationModel
+            elif is_gemini:
+                model_cls = GenerativeModel
+                preview_model_cls = GenerativeModel
             else:
                 model_cls = TextGenerationModel
                 preview_model_cls = PreviewTextGenerationModel
@@ -260,8 +257,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     tuned_model_name
                 )
             else:
-                values["client"] = model_cls.from_pretrained(model_name)
-                values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+                if is_gemini:
+                    values["client"] = model_cls(model_name=model_name)
+                    values["client_preview"] = preview_model_cls(model_name=model_name)
+                else:
+                    values["client"] = model_cls.from_pretrained(model_name)
+                    values["client_preview"] = preview_model_cls.from_pretrained(
+                        model_name
+                    )
 
         except ImportError:
             raise_vertex_import_error()
@@ -288,6 +291,19 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens
 
+    def _response_to_generation(
+        self, response: TextGenerationResponse
+    ) -> GenerationChunk:
+        """Converts a stream response to a generation chunk."""
+        try:
+            generation_info = {
+                "is_blocked": response.is_blocked,
+                "safety_attributes": response.safety_attributes,
+            }
+        except Exception:
+            generation_info = None
+        return GenerationChunk(text=response.text, generation_info=generation_info)
+
     def _generate(
         self,
         prompts: List[str],
@@ -298,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
         params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             if should_stream:
                 generation = GenerationChunk(text="")
@@ -309,9 +325,16 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 generations.append([generation])
             else:
                 res = completion_with_retry(
-                    self, prompt, run_manager=run_manager, **params
+                    self,
+                    [prompt],
+                    stream=should_stream,
+                    is_gemini=self._is_gemini_model,
+                    run_manager=run_manager,
+                    **params,
                 )
-                generations.append([_response_to_generation(r) for r in res.candidates])
+                generations.append(
+                    [self._response_to_generation(r) for r in res.candidates]
+                )
         return LLMResult(generations=generations)
 
     async def _agenerate(
@@ -325,9 +348,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
-                self, prompt, run_manager=run_manager, **params
+                self,
+                prompt,
+                is_gemini=self._is_gemini_model,
+                run_manager=run_manager,
+                **params,
             )
-            generations.append([_response_to_generation(r) for r in res.candidates])
+            generations.append(
+                [self._response_to_generation(r) for r in res.candidates]
+            )
         return LLMResult(generations=generations)
 
     def _stream(
@@ -338,10 +367,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in stream_completion_with_retry(
-            self, prompt, run_manager=run_manager, **params
+        for stream_resp in completion_with_retry(
+            self,
+            [prompt],
+            stream=True,
+            is_gemini=self._is_gemini_model,
+            run_manager=run_manager,
+            **params,
         ):
-            chunk = _response_to_generation(stream_resp)
+            chunk = self._response_to_generation(stream_resp)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(