langchain[patch]: fix ChatVertexAI streaming (#14369)

This commit is contained in:
Erick Friis
2023-12-07 09:46:11 -08:00
committed by GitHub
parent db6bf8b022
commit 54040b00a4
3 changed files with 102 additions and 129 deletions

View File

@@ -242,7 +242,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
) -> Iterator[ChatGenerationChunk]:
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])
params = self._prepare_params(stop=stop, **kwargs)
params = self._prepare_params(stop=stop, stream=True, **kwargs)
examples = kwargs.get("examples", None)
if examples:
params["examples"] = _parse_examples(examples)

View File

@@ -11,7 +11,12 @@ from typing import Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import LLMResult
from langchain.chat_models import ChatVertexAI
@@ -41,6 +46,7 @@ def test_vertexai_single_call(model_name: str) -> None:
assert isinstance(response.content, str)
@pytest.mark.scheduled
def test_candidates() -> None:
model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
message = HumanMessage(content="Hello")
@@ -62,6 +68,16 @@ async def test_vertexai_agenerate() -> None:
assert response.generations[0][0] == sync_response.generations[0][0]
@pytest.mark.scheduled
def test_vertexai_stream() -> None:
    """Verify that ``ChatVertexAI.stream`` yields ``AIMessageChunk`` objects.

    Declared as a plain (sync) test: the body contains no ``await`` — it
    iterates the synchronous generator returned by ``stream()`` — so an
    ``async def`` here would be collected as an unexecuted coroutine by
    plain pytest and never actually run its assertions.
    """
    model = ChatVertexAI(temperature=0)
    message = HumanMessage(content="Hello")

    # stream() returns a synchronous iterator of incremental chunks.
    stream_iter = model.stream([message])
    for chunk in stream_iter:
        assert isinstance(chunk, AIMessageChunk)
@pytest.mark.scheduled
def test_vertexai_single_call_with_context() -> None:
model = ChatVertexAI()