mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
changed default params for gemini (#16044)
Replace this entire comment with: - **Description:** changed default values for Vertex LLMs (to be handled on the SDK's side)
This commit is contained in:
parent
ec9642d667
commit
58f0ba306b
@ -41,6 +41,11 @@ from langchain_google_vertexai._utils import (
|
|||||||
is_gemini_model,
|
is_gemini_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
|
||||||
|
_PALM_DEFAULT_TEMPERATURE = 0.0
|
||||||
|
_PALM_DEFAULT_TOP_P = 0.95
|
||||||
|
_PALM_DEFAULT_TOP_K = 40
|
||||||
|
|
||||||
|
|
||||||
def _completion_with_retry(
|
def _completion_with_retry(
|
||||||
llm: VertexAI,
|
llm: VertexAI,
|
||||||
@ -118,14 +123,14 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
client_preview: Any = None #: :meta private:
|
client_preview: Any = None #: :meta private:
|
||||||
model_name: str
|
model_name: str
|
||||||
"Underlying model name."
|
"Underlying model name."
|
||||||
temperature: float = 0.0
|
temperature: Optional[float] = None
|
||||||
"Sampling temperature, it controls the degree of randomness in token selection."
|
"Sampling temperature, it controls the degree of randomness in token selection."
|
||||||
max_output_tokens: int = 128
|
max_output_tokens: Optional[int] = None
|
||||||
"Token limit determines the maximum amount of text output from one prompt."
|
"Token limit determines the maximum amount of text output from one prompt."
|
||||||
top_p: float = 0.95
|
top_p: Optional[float] = None
|
||||||
"Tokens are selected from most probable to least until the sum of their "
|
"Tokens are selected from most probable to least until the sum of their "
|
||||||
"probabilities equals the top-p value. Top-p is ignored for Codey models."
|
"probabilities equals the top-p value. Top-p is ignored for Codey models."
|
||||||
top_k: int = 40
|
top_k: Optional[int] = None
|
||||||
"How the model selects tokens for output, the next token is selected from "
|
"How the model selects tokens for output, the next token is selected from "
|
||||||
"among the top-k most probable tokens. Top-k is ignored for Codey models."
|
"among the top-k most probable tokens. Top-k is ignored for Codey models."
|
||||||
credentials: Any = Field(default=None, exclude=True)
|
credentials: Any = Field(default=None, exclude=True)
|
||||||
@ -156,6 +161,15 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
if self._is_gemini_model:
|
||||||
|
default_params = {}
|
||||||
|
else:
|
||||||
|
default_params = {
|
||||||
|
"temperature": _PALM_DEFAULT_TEMPERATURE,
|
||||||
|
"max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
|
||||||
|
"top_p": _PALM_DEFAULT_TOP_P,
|
||||||
|
"top_k": _PALM_DEFAULT_TOP_K,
|
||||||
|
}
|
||||||
params = {
|
params = {
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"max_output_tokens": self.max_output_tokens,
|
"max_output_tokens": self.max_output_tokens,
|
||||||
@ -168,7 +182,14 @@ class _VertexAICommon(_VertexAIBase):
|
|||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
updated_params = {}
|
||||||
|
for param_name, param_value in params.items():
|
||||||
|
default_value = default_params.get(param_name)
|
||||||
|
if param_value or default_value:
|
||||||
|
updated_params[param_name] = (
|
||||||
|
param_value if param_value else default_value
|
||||||
|
)
|
||||||
|
return updated_params
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _init_vertexai(cls, values: Dict) -> None:
|
def _init_vertexai(cls, values: Dict) -> None:
|
||||||
@ -314,7 +335,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
params = self._prepare_params(stop=stop, **kwargs)
|
params = self._prepare_params(stop=stop, **kwargs)
|
||||||
generations = []
|
generations: List[List[Generation]] = []
|
||||||
for prompt in prompts:
|
for prompt in prompts:
|
||||||
res = await _acompletion_with_retry(
|
res = await _acompletion_with_retry(
|
||||||
self,
|
self,
|
||||||
|
@ -68,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
|
|||||||
mock_model.start_chat = mock_start_chat
|
mock_model.start_chat = mock_start_chat
|
||||||
mg.return_value = mock_model
|
mg.return_value = mock_model
|
||||||
|
|
||||||
model = ChatVertexAI(**prompt_params)
|
model = ChatVertexAI(**prompt_params) # type: ignore
|
||||||
message = HumanMessage(content=user_prompt)
|
message = HumanMessage(content=user_prompt)
|
||||||
if stop:
|
if stop:
|
||||||
response = model([message], stop=[stop])
|
response = model([message], stop=[stop])
|
||||||
@ -110,3 +110,52 @@ def test_parse_chat_history_correct() -> None:
|
|||||||
ChatMessage(content=text_question, author="user"),
|
ChatMessage(content=text_question, author="user"),
|
||||||
ChatMessage(content=text_answer, author="bot"),
|
ChatMessage(content=text_answer, author="bot"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_params_palm() -> None:
|
||||||
|
user_prompt = "Hello"
|
||||||
|
|
||||||
|
with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.candidates = [Mock(text="Goodbye")]
|
||||||
|
mock_chat = MagicMock()
|
||||||
|
mock_send_message = MagicMock(return_value=mock_response)
|
||||||
|
mock_chat.send_message = mock_send_message
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_start_chat = MagicMock(return_value=mock_chat)
|
||||||
|
mock_model.start_chat = mock_start_chat
|
||||||
|
mg.return_value = mock_model
|
||||||
|
|
||||||
|
model = ChatVertexAI(model_name="text-bison@001")
|
||||||
|
message = HumanMessage(content=user_prompt)
|
||||||
|
_ = model([message])
|
||||||
|
mock_start_chat.assert_called_once_with(
|
||||||
|
context=None,
|
||||||
|
message_history=[],
|
||||||
|
max_output_tokens=128,
|
||||||
|
top_k=40,
|
||||||
|
top_p=0.95,
|
||||||
|
stop_sequences=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_params_gemini() -> None:
|
||||||
|
user_prompt = "Hello"
|
||||||
|
|
||||||
|
with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
|
||||||
|
mock_response = MagicMock()
|
||||||
|
content = Mock(parts=[Mock(function_call=None)])
|
||||||
|
mock_response.candidates = [Mock(text="Goodbye", content=content)]
|
||||||
|
mock_chat = MagicMock()
|
||||||
|
mock_send_message = MagicMock(return_value=mock_response)
|
||||||
|
mock_chat.send_message = mock_send_message
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_start_chat = MagicMock(return_value=mock_chat)
|
||||||
|
mock_model.start_chat = mock_start_chat
|
||||||
|
gm.return_value = mock_model
|
||||||
|
model = ChatVertexAI(model_name="gemini-pro")
|
||||||
|
message = HumanMessage(content=user_prompt)
|
||||||
|
_ = model([message])
|
||||||
|
mock_start_chat.assert_called_once_with(history=[])
|
||||||
|
Loading…
Reference in New Issue
Block a user