google-vertexai[patch]: streaming bug (#16603)

Fixes the streaming errors seen in this CI run:
https://github.com/langchain-ai/langchain/actions/runs/7661680517/job/20881556592#step:9:229
Commit: 5b5115c408
Parent: a989f82027
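Background on the bug: when a response is streamed, each chunk carries its own generation_info, and those dicts are merged as chunks are concatenated. A boolean such as is_blocked cannot be merged once two chunks disagree, which is why the new stream flag below drops non-mergeable values. A minimal repro sketch, assuming GenerationChunk combines generation_info via langchain_core's merge_dicts:

# Minimal repro sketch: concatenating streamed chunks merges their
# generation_info dicts, and two differing bools cannot be merged.
from langchain_core.outputs import GenerationChunk

left = GenerationChunk(text="Say ", generation_info={"is_blocked": False})
right = GenerationChunk(text="foo", generation_info={"is_blocked": True})

try:
    left + right
except TypeError as err:
    # e.g. "... value has unsupported type <class 'bool'>"
    print(f"merge failed: {err}")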
@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
 def get_generation_info(
-    candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
+    candidate: Union[TextGenerationResponse, Candidate],
+    is_gemini: bool,
+    *,
+    stream: bool = False,
 ) -> Dict[str, Any]:
     if is_gemini:
         # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
-        return {
+        info = {
             "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
             "safety_ratings": [
                 {
                     "category": rating.category.name,
                     "probability_label": rating.probability.name,
+                    "blocked": rating.blocked,
                 }
                 for rating in candidate.safety_ratings
             ],
             "citation_metadata": candidate.citation_metadata,
         }
     # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
-    candidate_dc = dataclasses.asdict(candidate)
-    candidate_dc.pop("text")
-    return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
+    else:
+        info = dataclasses.asdict(candidate)
+        info.pop("text")
+        info = {k: v for k, v in info.items() if not k.startswith("_")}
+    if stream:
+        # Remove non-streamable types, like bools.
+        info.pop("is_blocked")
+    return info
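For illustration, a self-contained sketch of the contract the reworked helper follows; FakeRating and FakeCandidate are hypothetical stand-ins for the Vertex AI SDK types, not part of the commit:

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeRating:
    # Hypothetical stand-in for a Gemini SafetyRating.
    category: str
    probability: str
    blocked: bool


@dataclass
class FakeCandidate:
    # Hypothetical stand-in for a Gemini Candidate.
    safety_ratings: List[FakeRating] = field(default_factory=list)
    citation_metadata: Any = None


def generation_info_sketch(
    candidate: FakeCandidate, *, stream: bool = False
) -> Dict[str, Any]:
    info: Dict[str, Any] = {
        "is_blocked": any(r.blocked for r in candidate.safety_ratings),
        "safety_ratings": [
            {
                "category": r.category,
                "probability_label": r.probability,
                "blocked": r.blocked,
            }
            for r in candidate.safety_ratings
        ],
        "citation_metadata": candidate.citation_metadata,
    }
    if stream:
        # Bools cannot be merged across streamed chunks, so drop them.
        info.pop("is_blocked")
    return info


ratings = [FakeRating("HARM_CATEGORY_DANGEROUS_CONTENT", "NEGLIGIBLE", False)]
candidate = FakeCandidate(safety_ratings=ratings)
assert "is_blocked" in generation_info_sketch(candidate)
assert "is_blocked" not in generation_info_sketch(candidate, stream=True)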
@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens
 
     def _response_to_generation(
-        self, response: TextGenerationResponse
+        self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
         """Converts a stream response to a generation chunk."""
-        generation_info = get_generation_info(response, self._is_gemini_model)
+        generation_info = get_generation_info(
+            response, self._is_gemini_model, stream=stream
+        )
         try:
             text = response.text
         except AttributeError:
@@ -401,7 +403,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
             run_manager=run_manager,
             **params,
         ):
-            chunk = self._response_to_generation(stream_resp)
+            # Gemini models return GenerationResponse even when streaming, which has a
+            # candidates field.
+            stream_resp = (
+                stream_resp
+                if isinstance(stream_resp, TextGenerationResponse)
+                else stream_resp.candidates[0]
+            )
+            chunk = self._response_to_generation(stream_resp, stream=True)
             yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(
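With the candidate normalization above, streaming behaves the same for Gemini and the older text models. A usage sketch, assuming configured Google Cloud credentials; the model name is illustrative:

from langchain_google_vertexai import VertexAI

# stream() yields string tokens; generation_info stays mergeable because
# the per-chunk is_blocked bool is dropped under stream=True.
llm = VertexAI(model_name="gemini-pro", temperature=0)
for token in llm.stream("Say foo:"):
    print(token, end="", flush=True)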
@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
     "model_name",
     model_names_to_test_with_default,
 )
-def test_vertex_call(model_name: str) -> None:
+def test_vertex_invoke(model_name: str) -> None:
     llm = (
         VertexAI(model_name=model_name, temperature=0)
         if model_name
         else VertexAI(temperature=0.0)
     )
-    output = llm("Say foo:")
+    output = llm.invoke("Say foo:")
     assert isinstance(output, str)
 
 
+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_generate(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+
+
 @pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
-def test_vertex_generate() -> None:
+def test_vertex_generate_multiple_candidates() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
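The renamed tests mirror the two call styles: invoke returns a plain string, while generate returns an LLMResult with one inner list of generations per prompt. A hedged sketch of the shapes the assertions check, built from langchain_core types rather than a live model:

from langchain_core.outputs import Generation, LLMResult

# Shape of what llm.generate(["Say foo:"]) returns: one inner list per
# prompt, each holding that prompt's candidate generations.
result = LLMResult(generations=[[Generation(text="foo")]])
assert len(result.generations) == 1  # one prompt
assert result.generations[0][0].text == "foo"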