Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-08 14:05:16 +00:00
added candidate_count for Vertex models (#11729)
- **Description:** added support for the `candidate_count` parameter on Vertex models
This commit is contained in: parent 9d200e6cbe, commit 9f0a718198
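
The new `n` field on the Vertex models is forwarded to Vertex AI's `candidate_count`, so one call can return several completions. A minimal usage sketch based on the tests added below, assuming Vertex AI credentials are already configured in the environment:

```python
from langchain.chat_models import ChatVertexAI
from langchain.llms import VertexAI
from langchain.schema.messages import HumanMessage

# Chat model: n is mapped to Vertex AI's candidate_count.
chat = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
chat_result = chat.generate(messages=[[HumanMessage(content="Hello")]])
assert len(chat_result.generations[0]) == 2  # two candidates for one prompt

# Plain LLM: same parameter, same behaviour.
llm = VertexAI(model_name="text-bison@001", temperature=0.3, n=2)
llm_result = llm.generate(["Say foo:"])
assert len(llm_result.generations[0]) == 2

# Streaming only supports a single candidate; streaming=True with n > 1
# raises ValueError("Only one candidate can be generated with streaming!").
```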
**`ChatVertexAI` (chat model):**

```diff
@@ -12,10 +12,7 @@ from langchain.callbacks.manager import (
 from langchain.chat_models.base import BaseChatModel, _generate_from_stream
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.pydantic_v1 import root_validator
-from langchain.schema import (
-    ChatGeneration,
-    ChatResult,
-)
+from langchain.schema import ChatGeneration, ChatResult
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -177,16 +174,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
 
         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=False, **kwargs)
         examples = kwargs.get("examples", None)
         if examples:
             params["examples"] = _parse_examples(examples)
 
-        chat = self._start_chat(history, params)
-        response = chat.send_message(question.content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=response.text))]
-        )
+        msg_params = {}
+        if "candidate_count" in params:
+            msg_params["candidate_count"] = params.pop("candidate_count")
+
+        chat = self._start_chat(history, **params)
+        response = chat.send_message(question.content, **msg_params)
+        generations = [
+            ChatGeneration(message=AIMessage(content=r.text))
+            for r in response.candidates
+        ]
+        return ChatResult(generations=generations)
 
     async def _agenerate(
         self,
@@ -219,11 +222,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if examples:
             params["examples"] = _parse_examples(examples)
 
-        chat = self._start_chat(history, params)
-        response = await chat.send_message_async(question.content)
-        return ChatResult(
-            generations=[ChatGeneration(message=AIMessage(content=response.text))]
-        )
+        msg_params = {}
+        if "candidate_count" in params:
+            msg_params["candidate_count"] = params.pop("candidate_count")
+        chat = self._start_chat(history, **params)
+        response = await chat.send_message_async(question.content, **msg_params)
+        generations = [
+            ChatGeneration(message=AIMessage(content=r.text))
+            for r in response.candidates
+        ]
+        return ChatResult(generations=generations)
 
     def _stream(
         self,
@@ -239,7 +247,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         if examples:
             params["examples"] = _parse_examples(examples)
 
-        chat = self._start_chat(history, params)
+        chat = self._start_chat(history, **params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
             if run_manager:
@@ -247,11 +255,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
 
     def _start_chat(
-        self, history: _ChatHistory, params: dict
+        self, history: _ChatHistory, **kwargs: Any
     ) -> Union[ChatSession, CodeChatSession]:
         if not self.is_codey_model:
             return self.client.start_chat(
-                context=history.context, message_history=history.history, **params
+                context=history.context, message_history=history.history, **kwargs
             )
         else:
-            return self.client.start_chat(message_history=history.history, **params)
+            return self.client.start_chat(message_history=history.history, **kwargs)
```
**`_VertexAICommon` / `VertexAI` (LLM):**

```diff
@@ -175,7 +175,10 @@ class _VertexAICommon(_VertexAIBase):
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
+    n: int = 1
+    """How many completions to generate for each prompt."""
     streaming: bool = False
+    """Whether to stream the results or not."""
 
     @property
     def _llm_type(self) -> str:
@@ -203,6 +206,7 @@ class _VertexAICommon(_VertexAIBase):
             "max_output_tokens": self.max_output_tokens,
             "top_k": self.top_k,
             "top_p": self.top_p,
+            "candidate_count": self.n,
         }
 
     @classmethod
@@ -215,10 +219,16 @@ class _VertexAICommon(_VertexAIBase):
     def _prepare_params(
         self,
         stop: Optional[List[str]] = None,
+        stream: bool = False,
         **kwargs: Any,
     ) -> dict:
         stop_sequences = stop or self.stop
-        return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
+        params_mapping = {"n": "candidate_count"}
+        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
+        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
+        if stream or self.streaming:
+            params.pop("candidate_count")
+        return params
 
 
 class VertexAI(_VertexAICommon, BaseLLM):
@@ -260,6 +270,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 values["client"] = CodeGenerationModel.from_pretrained(model_name)
         except ImportError:
             raise_vertex_import_error()
+
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Only one candidate can be generated with streaming!")
         return values
 
     def _generate(
@@ -271,7 +284,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         should_stream = stream if stream is not None else self.streaming
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
         generations = []
         for prompt in prompts:
             if should_stream:
@@ -285,7 +298,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
                 res = completion_with_retry(
                     self, prompt, run_manager=run_manager, **params
                 )
-                generations.append([_response_to_generation(res)])
+                generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)
 
     async def _agenerate(
@@ -301,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
             res = await acompletion_with_retry(
                 self, prompt, run_manager=run_manager, **params
             )
-            generations.append([_response_to_generation(res)])
+            generations.append([_response_to_generation(r) for r in res.candidates])
         return LLMResult(generations=generations)
 
     def _stream(
@@ -311,7 +324,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        params = self._prepare_params(stop=stop, **kwargs)
+        params = self._prepare_params(stop=stop, stream=True, **kwargs)
         for stream_resp in stream_completion_with_retry(
             self, prompt, run_manager=run_manager, **params
         ):
```
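The core of the change sits in `_prepare_params`: the user-facing `n` is renamed to Vertex AI's `candidate_count` before the request is built, and the key is dropped again when streaming, because a streamed call returns only a single candidate. A standalone sketch of that mapping logic, where the defaults dict is only an illustrative stand-in for `self._default_params`:

```python
from typing import Any, Dict, List, Optional

# Illustrative stand-in for the model's _default_params property.
DEFAULT_PARAMS: Dict[str, Any] = {"temperature": 0.0, "candidate_count": 1}


def prepare_params(
    stop: Optional[List[str]] = None, stream: bool = False, **kwargs: Any
) -> dict:
    """Mirrors the renaming and streaming logic added to _prepare_params."""
    params_mapping = {"n": "candidate_count"}
    # Rename user-facing keys (n -> candidate_count) before merging with defaults.
    params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
    params = {**DEFAULT_PARAMS, "stop_sequences": stop, **params}
    if stream:
        # Streaming returns a single candidate, so the key is removed.
        params.pop("candidate_count")
    return params


print(prepare_params(n=2))
# {'temperature': 0.0, 'candidate_count': 2, 'stop_sequences': None}
print(prepare_params(n=2, stream=True))
# {'temperature': 0.0, 'stop_sequences': None}
```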
**Vertex utilities (minimum SDK version bump):**

```diff
@@ -5,7 +5,7 @@ if TYPE_CHECKING:
     from google.auth.credentials import Credentials
 
 
-def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
+def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
     """Raise ImportError related to Vertex SDK being not available.
 
     Args:
```
**Chat model integration tests:**

```diff
@@ -20,7 +20,10 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 
 @pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
 def test_vertexai_instantiation(model_name: str) -> None:
-    model = ChatVertexAI(model_name=model_name)
+    if model_name:
+        model = ChatVertexAI(model_name=model_name)
+    else:
+        model = ChatVertexAI()
     assert model._llm_type == "vertexai"
     assert model.model_name == model.client._model_id
 
@@ -38,6 +41,15 @@ def test_vertexai_single_call(model_name: str) -> None:
     assert isinstance(response.content, str)
 
 
+def test_candidates() -> None:
+    model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
+    message = HumanMessage(content="Hello")
+    response = model.generate(messages=[[message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 1
+    assert len(response.generations[0]) == 2
+
+
 @pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_vertexai_agenerate() -> None:
@@ -153,7 +165,8 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
     with patch(
         "vertexai.language_models._language_models.ChatModel.start_chat"
     ) as start_chat:
-        mock_response = Mock(text=response_text)
+        mock_response = MagicMock()
+        mock_response.candidates = [Mock(text=response_text)]
         mock_chat = MagicMock()
         start_chat.return_value = mock_chat
         mock_send_message = MagicMock(return_value=mock_response)
@@ -167,7 +180,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
     response = model([message])
 
     assert response.content == response_text
-    mock_send_message.assert_called_once_with(user_prompt)
+    mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
     expected_stop_sequence = [stop] if stop else None
     start_chat.assert_called_once_with(
         context=None,
```
**LLM integration tests:**

```diff
@@ -29,29 +29,31 @@ def test_vertex_call() -> None:
 
 @pytest.mark.scheduled
 def test_vertex_generate() -> None:
-    llm = VertexAI(temperate=0)
-    output = llm.generate(["Please say foo:"])
+    llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
+    output = llm.generate(["Say foo:"])
     assert isinstance(output, LLMResult)
+    assert len(output.generations) == 1
+    assert len(output.generations[0]) == 2
 
 
 @pytest.mark.scheduled
 @pytest.mark.asyncio
 async def test_vertex_agenerate() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     output = await llm.agenerate(["Please say foo:"])
     assert isinstance(output, LLMResult)
 
 
 @pytest.mark.scheduled
 def test_vertex_stream() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     outputs = list(llm.stream("Please say foo:"))
     assert isinstance(outputs[0], str)
 
 
 @pytest.mark.asyncio
 async def test_vertex_consistency() -> None:
-    llm = VertexAI(temperate=0)
+    llm = VertexAI(temperature=0)
     output = llm.generate(["Please say foo:"])
     streaming_output = llm.generate(["Please say foo:"], stream=True)
     async_output = await llm.agenerate(["Please say foo:"])
```