Update Vertex AI to include Gemini (#14670)

h/t to @lkuligin 
- **Description:** added support for new models (Gemini) on Vertex AI; a usage sketch follows below
- **Twitter handle:** @lkuligin
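
A minimal usage sketch of what the change enables, assuming Google Cloud credentials and the Vertex AI SDK are already configured (the model names and calls are taken from the updated tests below):

```python
from langchain_community.llms import VertexAI

# Gemini can now be selected alongside the existing PaLM text models
# (e.g. "text-bison@001") via the model_name parameter.
llm = VertexAI(model_name="gemini-pro", temperature=0)
print(llm("Say foo:"))
```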

---------

Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
William FH authored on 2023-12-13 10:45:02 -08:00, committed by GitHub
commit 75b8891399 (parent 858f4cbce4)
6 changed files with 595 additions and 197 deletions


@@ -13,15 +13,33 @@ from langchain_core.outputs import LLMResult
 
 from langchain_community.llms import VertexAI, VertexAIModelGarden
 
+model_names_to_test = ["text-bison@001", "gemini-pro"]
+model_names_to_test_with_default = [None] + model_names_to_test
+
 
-def test_vertex_initialization() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_initialization(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name) if model_name else VertexAI()
     assert llm._llm_type == "vertexai"
-    assert llm.model_name == llm.client._model_id
+    try:
+        assert llm.model_name == llm.client._model_id
+    except AttributeError:
+        assert llm.model_name == llm.client._model_name.split("/")[-1]
 
 
-def test_vertex_call() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_call(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
     output = llm("Say foo:")
     assert isinstance(output, str)
@@ -52,8 +70,16 @@ async def test_vertex_agenerate() -> None:
 
 
 @pytest.mark.scheduled
-def test_vertex_stream() -> None:
-    llm = VertexAI(temperature=0)
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_stream(model_name: str) -> None:
+    llm = (
+        VertexAI(temperature=0, model_name=model_name)
+        if model_name
+        else VertexAI(temperature=0)
+    )
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
@@ -145,7 +171,11 @@ async def test_model_garden_agenerate(
     assert len(output.generations) == 2
 
 
-def test_vertex_call_count_tokens() -> None:
-    llm = VertexAI()
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test,
+)
+def test_vertex_call_count_tokens(model_name: str) -> None:
+    llm = VertexAI(model_name=model_name)
     output = llm.get_num_tokens("How are you?")
     assert output == 4
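
The streaming and token-counting paths touched above can be exercised the same way; a small sketch, again assuming credentials are configured (the calls mirror test_vertex_stream and test_vertex_call_count_tokens):

```python
from langchain_community.llms import VertexAI

llm = VertexAI(model_name="gemini-pro", temperature=0)

# Streaming yields string chunks, as asserted in test_vertex_stream.
for chunk in llm.stream("Please say foo:"):
    print(chunk, end="")

# Token counting, as in test_vertex_call_count_tokens
# (the test expects 4 tokens for this prompt).
print(llm.get_num_tokens("How are you?"))
```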