Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-27 08:58:48 +00:00
changed default params for gemini (#16044)
- **Description:** changed default values for Vertex LLMs (to be handled on the SDK's side)
This commit is contained in:
parent ec9642d667
commit 58f0ba306b
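For context, the diff below stops pinning client-side sampling defaults for Gemini models: the temperature, max_output_tokens, top_p, and top_k fields now default to None, so unset values are left to the Vertex AI SDK, while PaLM-style models fall back to the new _PALM_DEFAULT_* constants. A minimal usage sketch (assumes langchain_google_vertexai is installed and Vertex AI credentials are configured; model names are illustrative):

from langchain_google_vertexai import VertexAI

# Gemini: no client-side defaults are injected; the SDK's own defaults apply.
gemini_llm = VertexAI(model_name="gemini-pro")

# Explicitly set parameters still take precedence for any model family.
pinned_llm = VertexAI(model_name="gemini-pro", temperature=0.2, top_p=0.8)

# PaLM-style models keep the old behaviour via the _PALM_DEFAULT_* fallbacks.
palm_llm = VertexAI(model_name="text-bison@001")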
@@ -41,6 +41,11 @@ from langchain_google_vertexai._utils import (
     is_gemini_model,
 )
 
+_PALM_DEFAULT_MAX_OUTPUT_TOKENS = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
+_PALM_DEFAULT_TEMPERATURE = 0.0
+_PALM_DEFAULT_TOP_P = 0.95
+_PALM_DEFAULT_TOP_K = 40
+
 
 def _completion_with_retry(
     llm: VertexAI,
@@ -118,14 +123,14 @@ class _VertexAICommon(_VertexAIBase):
     client_preview: Any = None  #: :meta private:
     model_name: str
     "Underlying model name."
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: int = 128
+    max_output_tokens: Optional[int] = None
     "Token limit determines the maximum amount of text output from one prompt."
-    top_p: float = 0.95
+    top_p: Optional[float] = None
     "Tokens are selected from most probable to least until the sum of their "
     "probabilities equals the top-p value. Top-p is ignored for Codey models."
-    top_k: int = 40
+    top_k: Optional[int] = None
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
     credentials: Any = Field(default=None, exclude=True)
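The field change above makes "unset" representable: None now means "not provided by the caller", so downstream code (or the SDK) can decide. A standalone Pydantic-style sketch of what the new declarations do (not the library class itself; the field names just mirror _VertexAICommon):

from typing import Optional

from pydantic import BaseModel


class SamplingParams(BaseModel):
    # Mirrors the Optional[...] = None fields introduced above.
    temperature: Optional[float] = None
    max_output_tokens: Optional[int] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None


unset = SamplingParams()                            # everything stays None
pinned = SamplingParams(temperature=0.2, top_k=10)  # explicit values are kept
print(unset.temperature, pinned.temperature)        # None 0.2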
@@ -156,6 +161,15 @@ class _VertexAICommon(_VertexAIBase):
 
     @property
     def _default_params(self) -> Dict[str, Any]:
+        if self._is_gemini_model:
+            default_params = {}
+        else:
+            default_params = {
+                "temperature": _PALM_DEFAULT_TEMPERATURE,
+                "max_output_tokens": _PALM_DEFAULT_MAX_OUTPUT_TOKENS,
+                "top_p": _PALM_DEFAULT_TOP_P,
+                "top_k": _PALM_DEFAULT_TOP_K,
+            }
         params = {
             "temperature": self.temperature,
             "max_output_tokens": self.max_output_tokens,
@@ -168,7 +182,14 @@ class _VertexAICommon(_VertexAIBase):
                     "top_p": self.top_p,
                 }
             )
-        return params
+        updated_params = {}
+        for param_name, param_value in params.items():
+            default_value = default_params.get(param_name)
+            if param_value or default_value:
+                updated_params[param_name] = (
+                    param_value if param_value else default_value
+                )
+        return updated_params
 
     @classmethod
     def _init_vertexai(cls, values: Dict) -> None:
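The new fallback loop in _default_params keeps truthy per-instance values and otherwise substitutes the defaults (the PaLM constants for PaLM models, nothing at all for Gemini). A minimal standalone sketch of that behaviour; the helper name _merge_with_defaults is illustrative, not part of the library:

from typing import Any, Dict, Optional


def _merge_with_defaults(
    params: Dict[str, Optional[Any]], defaults: Dict[str, Any]
) -> Dict[str, Any]:
    # Same rule as the loop above: keep a truthy value, otherwise fall back to
    # a truthy default, and drop the key entirely if both are falsy.
    updated: Dict[str, Any] = {}
    for name, value in params.items():
        default = defaults.get(name)
        if value or default:
            updated[name] = value if value else default
    return updated


palm_defaults = {"temperature": 0.0, "max_output_tokens": 128, "top_p": 0.95, "top_k": 40}
unset = {"temperature": None, "max_output_tokens": None, "top_p": None, "top_k": None}

print(_merge_with_defaults(unset, palm_defaults))
# {'max_output_tokens': 128, 'top_p': 0.95, 'top_k': 40} -- temperature is dropped
# because both the unset value and the 0.0 default are falsy, which is why
# test_default_params_palm further down asserts no temperature kwarg.

print(_merge_with_defaults(unset, {}))
# {} -- the Gemini path: with no defaults, unset parameters are simply omitted.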
@@ -314,7 +335,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         params = self._prepare_params(stop=stop, **kwargs)
-        generations = []
+        generations: List[List[Generation]] = []
         for prompt in prompts:
             res = await _acompletion_with_retry(
                 self,
@@ -68,7 +68,7 @@ def test_vertexai_args_passed(stop: Optional[str]) -> None:
         mock_model.start_chat = mock_start_chat
         mg.return_value = mock_model
 
-        model = ChatVertexAI(**prompt_params)
+        model = ChatVertexAI(**prompt_params)  # type: ignore
         message = HumanMessage(content=user_prompt)
         if stop:
             response = model([message], stop=[stop])
@@ -110,3 +110,52 @@ def test_parse_chat_history_correct() -> None:
         ChatMessage(content=text_question, author="user"),
         ChatMessage(content=text_answer, author="bot"),
     ]
+
+
+def test_default_params_palm() -> None:
+    user_prompt = "Hello"
+
+    with patch("vertexai._model_garden._model_garden_models._from_pretrained") as mg:
+        mock_response = MagicMock()
+        mock_response.candidates = [Mock(text="Goodbye")]
+        mock_chat = MagicMock()
+        mock_send_message = MagicMock(return_value=mock_response)
+        mock_chat.send_message = mock_send_message
+
+        mock_model = MagicMock()
+        mock_start_chat = MagicMock(return_value=mock_chat)
+        mock_model.start_chat = mock_start_chat
+        mg.return_value = mock_model
+
+        model = ChatVertexAI(model_name="text-bison@001")
+        message = HumanMessage(content=user_prompt)
+        _ = model([message])
+        mock_start_chat.assert_called_once_with(
+            context=None,
+            message_history=[],
+            max_output_tokens=128,
+            top_k=40,
+            top_p=0.95,
+            stop_sequences=None,
+        )
+
+
+def test_default_params_gemini() -> None:
+    user_prompt = "Hello"
+
+    with patch("langchain_google_vertexai.chat_models.GenerativeModel") as gm:
+        mock_response = MagicMock()
+        content = Mock(parts=[Mock(function_call=None)])
+        mock_response.candidates = [Mock(text="Goodbye", content=content)]
+        mock_chat = MagicMock()
+        mock_send_message = MagicMock(return_value=mock_response)
+        mock_chat.send_message = mock_send_message
+
+        mock_model = MagicMock()
+        mock_start_chat = MagicMock(return_value=mock_chat)
+        mock_model.start_chat = mock_start_chat
+        gm.return_value = mock_model
+        model = ChatVertexAI(model_name="gemini-pro")
+        message = HumanMessage(content=user_prompt)
+        _ = model([message])
+        mock_start_chat.assert_called_once_with(history=[])