google-vertexai[patch]: streaming bug (#16603)
Fixes the errors seen in https://github.com/langchain-ai/langchain/actions/runs/7661680517/job/20881556592#step:9:229
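The underlying failure, as I read the linked CI run (the commit itself only says to drop non-streamable types): get_generation_info puts the boolean is_blocked into every chunk's generation_info, and when streamed chunks are concatenated their generation_info dicts are merged, with no rule for combining two differing bools. A minimal standalone sketch of that merge failing, assuming langchain_core's GenerationChunk semantics (this reproduction is mine, not part of the patch):

from langchain_core.outputs import GenerationChunk

left = GenerationChunk(text="foo", generation_info={"is_blocked": False})
right = GenerationChunk(text="bar", generation_info={"is_blocked": True})

try:
    total = left + right  # text concatenates; generation_info dicts must merge
except TypeError as err:
    print(err)  # bool is not a mergeable type, so streaming aggregation fails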
parent a989f82027
commit 5b5115c408
@@ -97,22 +97,31 @@ def is_gemini_model(model_name: str) -> bool:
 def get_generation_info(
-    candidate: Union[TextGenerationResponse, Candidate], is_gemini: bool
+    candidate: Union[TextGenerationResponse, Candidate],
+    is_gemini: bool,
+    *,
+    stream: bool = False,
 ) -> Dict[str, Any]:
     if is_gemini:
         # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#response_body
-        return {
+        info = {
             "is_blocked": any([rating.blocked for rating in candidate.safety_ratings]),
             "safety_ratings": [
                 {
                     "category": rating.category.name,
                     "probability_label": rating.probability.name,
                     "blocked": rating.blocked,
                 }
                 for rating in candidate.safety_ratings
             ],
             "citation_metadata": candidate.citation_metadata,
         }
     # https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat#response_body
-    candidate_dc = dataclasses.asdict(candidate)
-    candidate_dc.pop("text")
-    return {k: v for k, v in candidate_dc.items() if not k.startswith("_")}
+    else:
+        info = dataclasses.asdict(candidate)
+        info.pop("text")
+        info = {k: v for k, v in info.items() if not k.startswith("_")}
+    if stream:
+        # Remove non-streamable types, like bools.
+        info.pop("is_blocked")
+    return info
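In effect the helper now has a batch mode and a stream mode. A hypothetical call site (candidate stands in for a real TextGenerationResponse or Candidate; these lines are illustration, not from the patch):

info = get_generation_info(candidate, is_gemini)  # batch: full info, including is_blocked
chunk_info = get_generation_info(candidate, is_gemini, stream=True)
assert "is_blocked" not in chunk_info  # stripped so chunk merging stays safe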
@@ -315,10 +315,12 @@ class VertexAI(_VertexAICommon, BaseLLM):
         return result.total_tokens

     def _response_to_generation(
-        self, response: TextGenerationResponse
+        self, response: TextGenerationResponse, *, stream: bool = False
     ) -> GenerationChunk:
         """Converts a stream response to a generation chunk."""
-        generation_info = get_generation_info(response, self._is_gemini_model)
+        generation_info = get_generation_info(
+            response, self._is_gemini_model, stream=stream
+        )
         try:
             text = response.text
         except AttributeError:
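Note the bare * in both new signatures makes stream keyword-only, so a stray positional bool can never flip it by accident. A tiny standalone illustration with a hypothetical function (not from the patch):

def f(response, *, stream: bool = False):
    return stream

f("resp")               # False: batch default
f("resp", stream=True)  # True: explicit opt-in
# f("resp", True) raises TypeError: f() takes 1 positional argument but 2 were given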
@@ -401,7 +403,14 @@ class VertexAI(_VertexAICommon, BaseLLM):
             run_manager=run_manager,
             **params,
         ):
-            chunk = self._response_to_generation(stream_resp)
+            # Gemini models return GenerationResponse even when streaming, which has a
+            # candidates field.
+            stream_resp = (
+                stream_resp
+                if isinstance(stream_resp, TextGenerationResponse)
+                else stream_resp.candidates[0]
+            )
+            chunk = self._response_to_generation(stream_resp, stream=True)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(
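For context, a hedged sketch of the user-visible path this hunk repairs: streaming from a Gemini model through the package's VertexAI wrapper (requires Google Cloud credentials; the model name is only an example):

from langchain_google_vertexai import VertexAI

llm = VertexAI(model_name="gemini-pro")
# Before this patch, Gemini streaming could fail: each streamed
# GenerationResponse was passed through without unwrapping .candidates[0],
# and chunks carried the non-mergeable boolean is_blocked.
for chunk in llm.stream("Say foo:"):  # chunks are plain strings for LLMs
    print(chunk, end="")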
@@ -32,18 +32,33 @@ def test_vertex_initialization(model_name: str) -> None:
     "model_name",
     model_names_to_test_with_default,
 )
-def test_vertex_call(model_name: str) -> None:
+def test_vertex_invoke(model_name: str) -> None:
     llm = (
         VertexAI(model_name=model_name, temperature=0)
         if model_name
         else VertexAI(temperature=0.0)
     )
-    output = llm("Say foo:")
+    output = llm.invoke("Say foo:")
     assert isinstance(output, str)


+@pytest.mark.parametrize(
+    "model_name",
+    model_names_to_test_with_default,
+)
+def test_vertex_generate(model_name: str) -> None:
+    llm = (
+        VertexAI(model_name=model_name, temperature=0)
+        if model_name
+        else VertexAI(temperature=0.0)
+    )
+    output = llm.generate(["Say foo:"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+
+
 @pytest.mark.xfail(reason="VertexAI doesn't always respect number of candidates")
-def test_vertex_generate() -> None:
+def test_vertex_generate_multiple_candidates() -> None:
     llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
     output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
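The test_vertex_call to test_vertex_invoke rename tracks LangChain 0.1's move away from calling an LLM object directly toward the Runnable invoke method; a short sketch of the two styles (assumes configured Google Cloud credentials):

from langchain_google_vertexai import VertexAI

llm = VertexAI(temperature=0.0)
assert isinstance(llm.invoke("Say foo:"), str)  # preferred Runnable API
# llm("Say foo:") still returns a str, but the direct-call form is deprecated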