mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
Update Vertex AI to include Gemini (#14670)
h/t to @lkuligin - **Description:** added new models on VertexAI - **Twitter handle:** @lkuligin --------- Co-authored-by: Leonid Kuligin <lkuligin@yandex.ru> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -13,15 +13,33 @@ from langchain_core.outputs import LLMResult
|
||||
|
||||
from langchain_community.llms import VertexAI, VertexAIModelGarden
|
||||
|
||||
model_names_to_test = ["text-bison@001", "gemini-pro"]
|
||||
model_names_to_test_with_default = [None] + model_names_to_test
|
||||
|
||||
def test_vertex_initialization() -> None:
|
||||
llm = VertexAI()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names_to_test_with_default,
|
||||
)
|
||||
def test_vertex_initialization(model_name: str) -> None:
|
||||
llm = VertexAI(model_name=model_name) if model_name else VertexAI()
|
||||
assert llm._llm_type == "vertexai"
|
||||
assert llm.model_name == llm.client._model_id
|
||||
try:
|
||||
assert llm.model_name == llm.client._model_id
|
||||
except AttributeError:
|
||||
assert llm.model_name == llm.client._model_name.split("/")[-1]
|
||||
|
||||
|
||||
def test_vertex_call() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names_to_test_with_default,
|
||||
)
|
||||
def test_vertex_call(model_name: str) -> None:
|
||||
llm = (
|
||||
VertexAI(model_name=model_name, temperature=0)
|
||||
if model_name
|
||||
else VertexAI(temperature=0.0)
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
@@ -52,8 +70,16 @@ async def test_vertex_agenerate() -> None:
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_stream() -> None:
|
||||
llm = VertexAI(temperature=0)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names_to_test_with_default,
|
||||
)
|
||||
def test_vertex_stream(model_name: str) -> None:
|
||||
llm = (
|
||||
VertexAI(temperature=0, model_name=model_name)
|
||||
if model_name
|
||||
else VertexAI(temperature=0)
|
||||
)
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
|
||||
@@ -145,7 +171,11 @@ async def test_model_garden_agenerate(
|
||||
assert len(output.generations) == 2
|
||||
|
||||
|
||||
def test_vertex_call_count_tokens() -> None:
|
||||
llm = VertexAI()
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
model_names_to_test,
|
||||
)
|
||||
def test_vertex_call_count_tokens(model_name: str) -> None:
|
||||
llm = VertexAI(model_name=model_name)
|
||||
output = llm.get_num_tokens("How are you?")
|
||||
assert output == 4
|
||||
|
Reference in New Issue
Block a user