integrations[patch]: remove non-required chat param defaults (#26730)

anthropic:
  - max_retries

openai:
  - n
  - temperature
  - max_retries

fireworks:
  - temperature

groq:
  - n
  - max_retries
  - temperature

mistral:
  - max_retries
  - timeout
  - max_concurrent_requests
  - temperature
  - top_p
  - safe_mode
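
Net effect: these params now default to None and are only forwarded when explicitly set, so each SDK or API applies its own defaults. A minimal sketch of the new behavior, using the same private _get_request_payload helper exercised by the unit test below (assumes OPENAI_API_KEY is set; model name illustrative):

    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-2024-08-06")
    payload = llm._get_request_payload([("human", "hello")])
    # Unset params no longer appear in the request; the API default applies.
    assert "temperature" not in payload and "n" not in payload

    # Explicitly-set values are still forwarded as before.
    llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0.2)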

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Bagatur 2025-01-06 17:26:22 -05:00 committed by GitHub
parent b9db8e9921
commit 3d7ae8b5d2
15 changed files with 51 additions and 43 deletions

View File

@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args client params:
         timeout: Optional[float]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries if a request fails.
         api_key: Optional[str]
             Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
@@ -558,8 +558,7 @@ class ChatAnthropic(BaseChatModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""
-    max_retries: int = 2
+    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
+    max_retries: Optional[int] = None
     """Number of retries allowed for requests sent to the Anthropic Completion API."""
     stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -662,9 +661,10 @@ class ChatAnthropic(BaseChatModel):
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
-            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
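
For ChatAnthropic, an unset max_retries is simply omitted from client_params, so the Anthropic SDK's own default (2 retries, per the comment above) applies. A hedged sketch of the effect (assumes ANTHROPIC_API_KEY is set; model name illustrative):

    from langchain_anthropic import ChatAnthropic

    # No max_retries passed: the key is left out of client_params and the
    # anthropic client falls back to its SDK default of 2 retries.
    llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

    # An explicit value is still forwarded to the client unchanged.
    llm_retry = ChatAnthropic(model="claude-3-5-sonnet-20240620", max_retries=5)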

View File

@@ -316,7 +316,7 @@ class ChatFireworks(BaseChatModel):
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""

View File

@@ -22,6 +22,7 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
+      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',

View File

@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args client params:
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
             Groq API key. If not passed in will be read from env var GROQ_API_KEY.
@@ -303,7 +303,7 @@ class ChatGroq(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -327,11 +327,11 @@ class ChatGroq(BaseChatModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
@@ -379,10 +379,11 @@ class ChatGroq(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
         if self.temperature == 0:
             self.temperature = 1e-8
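
With n now Optional, the bounds checks only fire when a value was actually supplied; None means "defer to the service default". The same logic as a standalone sketch (function name hypothetical):

    from typing import Optional

    def check_n(n: Optional[int], streaming: bool) -> None:
        # None means "unset": skip validation and let the API default apply.
        if n is not None and n < 1:
            raise ValueError("n must be at least 1.")
        elif n is not None and n > 1 and streaming:
            raise ValueError("n must be 1 when streaming.")

    check_n(None, streaming=True)  # fine: n was never set
    check_n(3, streaming=False)    # fine
    # check_n(3, streaming=True) would raise ValueError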
@@ -392,10 +393,11 @@ class ChatGroq(BaseChatModel):
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         try:
             import groq

View File

@@ -17,7 +17,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'mixtral-8x7b-32768',
-      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),

View File

@@ -95,9 +95,12 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
     errors = [httpx.RequestError, httpx.StreamError]
-    return create_base_retry_decorator(
+    kwargs: dict = dict(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
+    return create_base_retry_decorator(
+        **{k: v for k, v in kwargs.items() if v is not None}
+    )


 def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
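
Filtering None values out of the kwargs lets create_base_retry_decorator fall back to its own max_retries default rather than receiving an explicit None. The pattern in isolation, with a toy stand-in for the callee (names hypothetical):

    def make_decorator(error_types, max_retries: int = 3, run_manager=None):
        # Toy stand-in for create_base_retry_decorator: just report what it got.
        return f"retries={max_retries}"

    kwargs: dict = dict(error_types=[ValueError], max_retries=None, run_manager=None)
    # None entries are dropped, so the callee's own default (3) wins.
    print(make_decorator(**{k: v for k, v in kwargs.items() if v is not None}))
    # -> retries=3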
@@ -380,13 +383,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: int = 5
-    timeout: int = 120
-    max_concurrent_requests: int = 64
+    max_retries: Optional[int] = None
+    timeout: Optional[int] = None
+    max_concurrent_requests: Optional[int] = None
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     max_tokens: Optional[int] = None
-    top_p: float = 1
+    top_p: Optional[float] = None
     """Decode using nucleus sampling: consider the smallest set of tokens whose
     probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None

View File

@@ -9,7 +9,6 @@
       ]),
       'kwargs': dict({
         'endpoint': 'boo',
-        'max_concurrent_requests': 64,
         'max_retries': 2,
         'max_tokens': 100,
         'mistral_api_key': dict({
@@ -22,7 +21,6 @@
         'model': 'mistral-small',
         'temperature': 0.0,
         'timeout': 60,
-        'top_p': 1,
       }),
     'lc': 1,
     'name': 'ChatMistralAI',

View File

@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
         if self.disabled_params is None:
@@ -641,10 +641,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]

View File

@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
n: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ class BaseChatOpenAI(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
         # Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,10 +551,12 @@ class BaseChatOpenAI(BaseChatModel):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client
@@ -609,14 +611,14 @@ class BaseChatOpenAI(BaseChatModel):
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
+            "n": self.n,
+            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }
         params = {
             "model": self.model_name,
             "stream": self.streaming,
-            "n": self.n,
-            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }
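
Moving n and temperature into the None-filtered dict means they only reach the request payload when explicitly set. The filtering step in isolation (values illustrative):

    exclude_if_none = {"n": None, "temperature": 0.5, "max_tokens": None}
    params = {
        "model": "gpt-4o-2024-08-06",
        "stream": False,
        # Only entries with a concrete value survive; unset params are
        # omitted so the OpenAI API applies its own defaults.
        **{k: v for k, v in exclude_if_none.items() if v is not None},
    }
    print(params)  # {'model': 'gpt-4o-2024-08-06', 'stream': False, 'temperature': 0.5}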
@@ -1565,7 +1567,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
             OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.

View File

@@ -15,7 +15,6 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'AZURE_OPENAI_API_KEY',

View File

@@ -11,7 +11,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'gpt-3.5-turbo',
-      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'OPENAI_API_KEY',

View File

@@ -877,8 +877,6 @@ def test__get_request_payload() -> None:
         ],
         "model": "gpt-4o-2024-08-06",
         "stream": False,
-        "n": 1,
-        "temperature": 0.7,
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected

View File

@@ -8,7 +8,13 @@ TEST_FILE ?= tests/unit_tests/
 integration_test integration_tests: TEST_FILE=tests/integration_tests/

-test tests integration_test integration_tests:
-	poetry run pytest $(TEST_FILE)
+test tests:
+	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

 test_watch:
 	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

+integration_test integration_tests:
+	poetry run pytest $(TEST_FILE)
+
 ######################
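
This splits the previously combined target: `make test` now runs unit tests with network access blocked (pytest-socket's --disable-socket, with Unix sockets still permitted), so a unit test that accidentally hits the network fails fast, while `make integration_tests` keeps full network access for tests that genuinely need it.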

View File

@@ -320,9 +320,9 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        if self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
         client_params: dict = {
@@ -331,10 +331,11 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             ),
             "base_url": self.xai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if client_params["api_key"] is None:
             raise ValueError(
if client_params["api_key"] is None:
raise ValueError(

View File

@@ -10,7 +10,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'grok-beta',
-      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),