Update Vertex AI to include Gemini (#14670)
h/t to @lkuligin

- **Description:** added new models on VertexAI
- **Twitter handle:** @lkuligin

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
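For context, a minimal usage sketch of what this change enables. The import path, the example model names, and the need for configured Google Cloud credentials are assumptions on top of the diff, not part of it:

```python
# Minimal sketch, assuming google-cloud-aiplatform with the
# vertexai.preview.generative_models module and working GCP credentials.
from langchain_community.llms import VertexAI

# "gemini-pro" matches the new is_gemini_model() check, so the
# GenerativeModel client is used under the hood.
gemini_llm = VertexAI(model_name="gemini-pro", temperature=0.2, max_output_tokens=256)
print(gemini_llm.invoke("Write a haiku about Vertex AI."))

# Codey and PaLM text models keep the existing TextGenerationModel path.
code_llm = VertexAI(model_name="code-bison")
print(code_llm.invoke("Write a Python function that reverses a string."))
```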
@@ -1,17 +1,9 @@
 from __future__ import annotations

 from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    ClassVar,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union

-from langchain_core.callbacks import (
+from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
@@ -37,36 +29,30 @@ if TYPE_CHECKING:
         TextGenerationResponse,
         _LanguageModel,
     )
+    from vertexai.preview.generative_models import Image


-def _response_to_generation(
-    response: TextGenerationResponse,
-) -> GenerationChunk:
-    """Convert a stream response to a generation chunk."""
-    try:
-        generation_info = {
-            "is_blocked": response.is_blocked,
-            "safety_attributes": response.safety_attributes,
-        }
-    except Exception:
-        generation_info = None
-    return GenerationChunk(text=response.text, generation_info=generation_info)
+# This is for backwards compatibility
+# We can remove after `langchain` stops importing it
+_response_to_generation = None
+completion_with_retry = None
+stream_completion_with_retry = None


 def is_codey_model(model_name: str) -> bool:
-    """Returns True if the model name is a Codey model.
-
-    Args:
-        model_name: The model name to check.
-
-    Returns: True if the model name is a Codey model.
-    """
+    """Returns True if the model name is a Codey model."""
     return "code" in model_name


+def is_gemini_model(model_name: str) -> bool:
+    """Returns True if the model name is a Gemini model."""
+    return model_name is not None and "gemini" in model_name
+
+
 def completion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: List[Union[str, "Image"]],
+    stream: bool = False,
+    is_gemini: bool = False,
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -74,33 +60,25 @@ def completion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict(*args, **kwargs)
+    def _completion_with_retry(
+        prompt: List[Union[str, "Image"]], is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return llm.client.generate_content(
+                prompt, stream=stream, generation_config=kwargs
+            )
+        else:
+            if stream:
+                return llm.client.predict_streaming(prompt[0], **kwargs)
+            return llm.client.predict(prompt[0], **kwargs)

-    return _completion_with_retry(*args, **kwargs)
-
-
-def stream_completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = create_retry_decorator(
-        llm, max_retries=llm.max_retries, run_manager=run_manager
-    )
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict_streaming(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
+    return _completion_with_retry(prompt, is_gemini, **kwargs)


 async def acompletion_with_retry(
     llm: VertexAI,
-    *args: Any,
+    prompt: str,
+    is_gemini: bool = False,
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
 ) -> Any:
@@ -108,10 +86,16 @@ async def acompletion_with_retry(
     retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

     @retry_decorator
-    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return await llm.client.predict_async(*args, **kwargs)
+    async def _acompletion_with_retry(
+        prompt: str, is_gemini: bool = False, **kwargs: Any
+    ) -> Any:
+        if is_gemini:
+            return await llm.client.generate_content_async(
+                prompt, generation_config=kwargs
+            )
+        return await llm.client.predict_async(prompt, **kwargs)

-    return await _acompletion_with_retry(*args, **kwargs)
+    return await _acompletion_with_retry(prompt, is_gemini, **kwargs)


 class _VertexAIBase(BaseModel):
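The two hunks above fold the separate predict / predict_streaming / predict_async retry helpers into single `completion_with_retry` / `acompletion_with_retry` functions that branch on `is_gemini`. A standalone sketch of that dispatch, with hypothetical stand-in clients instead of the real SDK objects (the real helpers also wrap the call in a tenacity retry decorator, omitted here):

```python
from typing import Any, Iterator, List


class FakePalmClient:
    """Hypothetical stand-in for a TextGenerationModel-style client."""

    def predict(self, prompt: str, **kwargs: Any) -> str:
        return f"palm:{prompt}"

    def predict_streaming(self, prompt: str, **kwargs: Any) -> Iterator[str]:
        yield f"palm-chunk:{prompt}"


class FakeGeminiClient:
    """Hypothetical stand-in for a GenerativeModel-style client."""

    def generate_content(
        self, prompt: List[str], stream: bool = False, generation_config: Any = None
    ) -> str:
        return f"gemini:{prompt[0]} (config={generation_config})"


def dispatch(
    client: Any,
    prompt: List[str],
    stream: bool = False,
    is_gemini: bool = False,
    **kwargs: Any,
) -> Any:
    # Gemini clients take the full prompt list plus a generation_config dict;
    # PaLM/Codey clients take a single string and use a separate streaming call.
    if is_gemini:
        return client.generate_content(prompt, stream=stream, generation_config=kwargs)
    if stream:
        return client.predict_streaming(prompt[0], **kwargs)
    return client.predict(prompt[0], **kwargs)


print(dispatch(FakeGeminiClient(), ["hi"], is_gemini=True, temperature=0.1))
print(dispatch(FakePalmClient(), ["hi"]))
```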
@@ -169,9 +153,13 @@ class _VertexAICommon(_VertexAIBase):
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)

+    @property
+    def _is_gemini_model(self) -> bool:
+        return is_gemini_model(self.model_name)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
+        """Gets the identifying parameters."""
         return {**{"model_name": self.model_name}, **self._default_params}

     @property
@@ -232,9 +220,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
-        cls._try_init_vertexai(values)
         tuned_model_name = values.get("tuned_model_name")
         model_name = values["model_name"]
+        is_gemini = is_gemini_model(values["model_name"])
+        cls._try_init_vertexai(values)
         try:
             from vertexai.language_models import (
                 CodeGenerationModel,
@@ -247,9 +236,17 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 TextGenerationModel as PreviewTextGenerationModel,
             )

+            if is_gemini:
+                from vertexai.preview.generative_models import (
+                    GenerativeModel,
+                )
+
             if is_codey_model(model_name):
                 model_cls = CodeGenerationModel
                 preview_model_cls = PreviewCodeGenerationModel
+            elif is_gemini:
+                model_cls = GenerativeModel
+                preview_model_cls = GenerativeModel
             else:
                 model_cls = TextGenerationModel
                 preview_model_cls = PreviewTextGenerationModel
@@ -260,8 +257,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
                     tuned_model_name
                 )
             else:
-                values["client"] = model_cls.from_pretrained(model_name)
-                values["client_preview"] = preview_model_cls.from_pretrained(model_name)
+                if is_gemini:
+                    values["client"] = model_cls(model_name=model_name)
+                    values["client_preview"] = preview_model_cls(model_name=model_name)
+                else:
+                    values["client"] = model_cls.from_pretrained(model_name)
+                    values["client_preview"] = preview_model_cls.from_pretrained(
+                        model_name
+                    )

         except ImportError:
             raise_vertex_import_error()
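`validate_environment` now also has to construct the client differently: the classic classes are loaded with `from_pretrained`, while `GenerativeModel` is instantiated directly. A standalone sketch of that selection logic (the `pick_client` wrapper is invented for illustration and assumes a google-cloud-aiplatform version that ships `vertexai.preview.generative_models`):

```python
def pick_client(model_name: str):
    # Same class names the diff imports; only this wrapper function is invented.
    from vertexai.language_models import CodeGenerationModel, TextGenerationModel
    from vertexai.preview.generative_models import GenerativeModel

    if "code" in model_name:  # mirrors is_codey_model()
        return CodeGenerationModel.from_pretrained(model_name)
    if "gemini" in model_name:  # mirrors is_gemini_model()
        # Gemini clients are constructed directly rather than via from_pretrained.
        return GenerativeModel(model_name=model_name)
    return TextGenerationModel.from_pretrained(model_name)


# e.g. pick_client("gemini-pro"), pick_client("code-bison"), pick_client("text-bison")
```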
@@ -288,6 +291,19 @@ class VertexAI(_VertexAICommon, BaseLLM):

         return result.total_tokens

+    def _response_to_generation(
+        self, response: TextGenerationResponse
+    ) -> GenerationChunk:
+        """Converts a stream response to a generation chunk."""
+        try:
+            generation_info = {
+                "is_blocked": response.is_blocked,
+                "safety_attributes": response.safety_attributes,
+            }
+        except Exception:
+            generation_info = None
+        return GenerationChunk(text=response.text, generation_info=generation_info)
+
     def _generate(
         self,
         prompts: List[str],
@@ -298,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
         params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             if should_stream:
                 generation = GenerationChunk(text="")
@@ -309,9 +325,16 @@
                 generations.append([generation])
             else:
                 res = completion_with_retry(
-                    self, prompt, run_manager=run_manager, **params
+                    self,
+                    [prompt],
+                    stream=should_stream,
+                    is_gemini=self._is_gemini_model,
+                    run_manager=run_manager,
+                    **params,
+                )
+                generations.append(
+                    [self._response_to_generation(r) for r in res.candidates]
                 )
-                generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)

     async def _agenerate(
@@ -325,9 +348,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         generations = []
         for prompt in prompts:
             res = await acompletion_with_retry(
-                self, prompt, run_manager=run_manager, **params
+                self,
+                prompt,
+                is_gemini=self._is_gemini_model,
+                run_manager=run_manager,
+                **params,
+            )
+            generations.append(
+                [self._response_to_generation(r) for r in res.candidates]
             )
-            generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)

     def _stream(
@@ -338,10 +367,15 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in stream_completion_with_retry(
-            self, prompt, run_manager=run_manager, **params
+        for stream_resp in completion_with_retry(
+            self,
+            [prompt],
+            stream=True,
+            is_gemini=self._is_gemini_model,
+            run_manager=run_manager,
+            **params,
         ):
-            chunk = _response_to_generation(stream_resp)
+            chunk = self._response_to_generation(stream_resp)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
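With streaming now routed through `completion_with_retry(stream=True)`, `llm.stream()` follows the same code path for Gemini and PaLM models. A hedged usage sketch (credentials and package versions assumed as above):

```python
from langchain_community.llms import VertexAI

llm = VertexAI(model_name="gemini-pro")
# Each chunk comes from _stream(), i.e. from completion_with_retry(stream=True).
for chunk in llm.stream("List three things the Vertex AI SDK can do."):
    print(chunk, end="", flush=True)
```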