mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 21:50:25 +00:00
added candidate_count for Vertex models (#11729)
- **Description:** added support for `candidate_count` parameter on Vertex
This commit is contained in:
parent
9d200e6cbe
commit
9f0a718198
@ -12,10 +12,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chat_models.base import BaseChatModel, _generate_from_stream
|
||||
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain.schema import ChatGeneration, ChatResult
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@ -177,16 +174,22 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
|
||||
question = _get_question(messages)
|
||||
history = _parse_chat_history(messages[:-1])
|
||||
params = self._prepare_params(stop=stop, **kwargs)
|
||||
params = self._prepare_params(stop=stop, stream=False, **kwargs)
|
||||
examples = kwargs.get("examples", None)
|
||||
if examples:
|
||||
params["examples"] = _parse_examples(examples)
|
||||
|
||||
chat = self._start_chat(history, params)
|
||||
response = chat.send_message(question.content)
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=response.text))]
|
||||
)
|
||||
msg_params = {}
|
||||
if "candidate_count" in params:
|
||||
msg_params["candidate_count"] = params.pop("candidate_count")
|
||||
|
||||
chat = self._start_chat(history, **params)
|
||||
response = chat.send_message(question.content, **msg_params)
|
||||
generations = [
|
||||
ChatGeneration(message=AIMessage(content=r.text))
|
||||
for r in response.candidates
|
||||
]
|
||||
return ChatResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
@ -219,11 +222,16 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
if examples:
|
||||
params["examples"] = _parse_examples(examples)
|
||||
|
||||
chat = self._start_chat(history, params)
|
||||
response = await chat.send_message_async(question.content)
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=response.text))]
|
||||
)
|
||||
msg_params = {}
|
||||
if "candidate_count" in params:
|
||||
msg_params["candidate_count"] = params.pop("candidate_count")
|
||||
chat = self._start_chat(history, **params)
|
||||
response = await chat.send_message_async(question.content, **msg_params)
|
||||
generations = [
|
||||
ChatGeneration(message=AIMessage(content=r.text))
|
||||
for r in response.candidates
|
||||
]
|
||||
return ChatResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@ -239,7 +247,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
if examples:
|
||||
params["examples"] = _parse_examples(examples)
|
||||
|
||||
chat = self._start_chat(history, params)
|
||||
chat = self._start_chat(history, **params)
|
||||
responses = chat.send_message_streaming(question.content, **params)
|
||||
for response in responses:
|
||||
if run_manager:
|
||||
@ -247,11 +255,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
|
||||
|
||||
def _start_chat(
|
||||
self, history: _ChatHistory, params: dict
|
||||
self, history: _ChatHistory, **kwargs: Any
|
||||
) -> Union[ChatSession, CodeChatSession]:
|
||||
if not self.is_codey_model:
|
||||
return self.client.start_chat(
|
||||
context=history.context, message_history=history.history, **params
|
||||
context=history.context, message_history=history.history, **kwargs
|
||||
)
|
||||
else:
|
||||
return self.client.start_chat(message_history=history.history, **params)
|
||||
return self.client.start_chat(message_history=history.history, **kwargs)
|
||||
|
@ -175,7 +175,10 @@ class _VertexAICommon(_VertexAIBase):
|
||||
"The default custom credentials (google.auth.credentials.Credentials) to use "
|
||||
"when making API calls. If not provided, credentials will be ascertained from "
|
||||
"the environment."
|
||||
n: int = 1
|
||||
"""How many completions to generate for each prompt."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@ -203,6 +206,7 @@ class _VertexAICommon(_VertexAIBase):
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"candidate_count": self.n,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -215,10 +219,16 @@ class _VertexAICommon(_VertexAIBase):
|
||||
def _prepare_params(
|
||||
self,
|
||||
stop: Optional[List[str]] = None,
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
stop_sequences = stop or self.stop
|
||||
return {**self._default_params, "stop_sequences": stop_sequences, **kwargs}
|
||||
params_mapping = {"n": "candidate_count"}
|
||||
params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
|
||||
params = {**self._default_params, "stop_sequences": stop_sequences, **params}
|
||||
if stream or self.streaming:
|
||||
params.pop("candidate_count")
|
||||
return params
|
||||
|
||||
|
||||
class VertexAI(_VertexAICommon, BaseLLM):
|
||||
@ -260,6 +270,9 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
values["client"] = CodeGenerationModel.from_pretrained(model_name)
|
||||
except ImportError:
|
||||
raise_vertex_import_error()
|
||||
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Only one candidate can be generated with streaming!")
|
||||
return values
|
||||
|
||||
def _generate(
|
||||
@ -271,7 +284,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
params = self._prepare_params(stop=stop, **kwargs)
|
||||
params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
if should_stream:
|
||||
@ -285,7 +298,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
res = completion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
)
|
||||
generations.append([_response_to_generation(res)])
|
||||
generations.append([_response_to_generation(r) for r in res.candidates])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
async def _agenerate(
|
||||
@ -301,7 +314,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
res = await acompletion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
)
|
||||
generations.append([_response_to_generation(res)])
|
||||
generations.append([_response_to_generation(r) for r in res.candidates])
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
def _stream(
|
||||
@ -311,7 +324,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
params = self._prepare_params(stop=stop, **kwargs)
|
||||
params = self._prepare_params(stop=stop, stream=True, **kwargs)
|
||||
for stream_resp in stream_completion_with_retry(
|
||||
self, prompt, run_manager=run_manager, **params
|
||||
):
|
||||
|
@ -5,7 +5,7 @@ if TYPE_CHECKING:
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
|
||||
def raise_vertex_import_error(minimum_expected_version: str = "1.33.0") -> None:
|
||||
def raise_vertex_import_error(minimum_expected_version: str = "1.35.0") -> None:
|
||||
"""Raise ImportError related to Vertex SDK being not available.
|
||||
|
||||
Args:
|
||||
|
@ -20,7 +20,10 @@ from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
@pytest.mark.parametrize("model_name", [None, "codechat-bison", "chat-bison"])
|
||||
def test_vertexai_instantiation(model_name: str) -> None:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
if model_name:
|
||||
model = ChatVertexAI(model_name=model_name)
|
||||
else:
|
||||
model = ChatVertexAI()
|
||||
assert model._llm_type == "vertexai"
|
||||
assert model.model_name == model.client._model_id
|
||||
|
||||
@ -38,6 +41,15 @@ def test_vertexai_single_call(model_name: str) -> None:
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_candidates() -> None:
|
||||
model = ChatVertexAI(model_name="chat-bison@001", temperature=0.3, n=2)
|
||||
message = HumanMessage(content="Hello")
|
||||
response = model.generate(messages=[[message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 1
|
||||
assert len(response.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertexai_agenerate() -> None:
|
||||
@ -153,7 +165,8 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
||||
with patch(
|
||||
"vertexai.language_models._language_models.ChatModel.start_chat"
|
||||
) as start_chat:
|
||||
mock_response = Mock(text=response_text)
|
||||
mock_response = MagicMock()
|
||||
mock_response.candidates = [Mock(text=response_text)]
|
||||
mock_chat = MagicMock()
|
||||
start_chat.return_value = mock_chat
|
||||
mock_send_message = MagicMock(return_value=mock_response)
|
||||
@ -167,7 +180,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
||||
response = model([message])
|
||||
|
||||
assert response.content == response_text
|
||||
mock_send_message.assert_called_once_with(user_prompt)
|
||||
mock_send_message.assert_called_once_with(user_prompt, candidate_count=1)
|
||||
expected_stop_sequence = [stop] if stop else None
|
||||
start_chat.assert_called_once_with(
|
||||
context=None,
|
||||
|
@ -29,29 +29,31 @@ def test_vertex_call() -> None:
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_generate() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
output = llm.generate(["Please say foo:"])
|
||||
llm = VertexAI(temperature=0.3, n=2, model_name="text-bison@001")
|
||||
output = llm.generate(["Say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
assert len(output.generations) == 1
|
||||
assert len(output.generations[0]) == 2
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_agenerate() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
llm = VertexAI(temperature=0)
|
||||
output = await llm.agenerate(["Please say foo:"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
|
||||
@pytest.mark.scheduled
|
||||
def test_vertex_stream() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
llm = VertexAI(temperature=0)
|
||||
outputs = list(llm.stream("Please say foo:"))
|
||||
assert isinstance(outputs[0], str)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_consistency() -> None:
|
||||
llm = VertexAI(temperate=0)
|
||||
llm = VertexAI(temperature=0)
|
||||
output = llm.generate(["Please say foo:"])
|
||||
streaming_output = llm.generate(["Please say foo:"], stream=True)
|
||||
async_output = await llm.agenerate(["Please say foo:"])
|
||||
|
Loading…
Reference in New Issue
Block a user