mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
integrations[patch]: remove non-required chat param defaults (#26730)
anthropic:
- max_retries

openai:
- n
- temperature
- max_retries

fireworks:
- temperature

groq:
- n
- max_retries
- temperature

mistral:
- max_retries
- timeout
- max_concurrent_requests
- temperature
- top_p
- safe_mode

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent b9db8e9921
commit 3d7ae8b5d2
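Every change in this diff applies the same recipe: a constructor field that used to pin a concrete default becomes Optional with a default of None, and the value is forwarded to the underlying SDK client or request payload only when it was actually set, so the provider's own defaults apply otherwise. A minimal sketch of the pattern (the class and field names here are illustrative, not the exact langchain code):

from typing import Any, Dict, Optional

from pydantic import BaseModel


class ChatModelSketch(BaseModel):
    # Before: max_retries: int = 2 pinned langchain's own default.
    # After: None means "unset", so the provider SDK's default wins.
    max_retries: Optional[int] = None

    def _client_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {"api_key": "..."}
        # Forward the param only when the user actually set it.
        if self.max_retries is not None:
            params["max_retries"] = self.max_retries
        return params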
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args — client params:
         timeout: Optional[float]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries if a request fails.
         api_key: Optional[str]
             Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
@@ -558,8 +558,7 @@ class ChatAnthropic(BaseChatModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""
 
-    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Number of retries allowed for requests sent to the Anthropic Completion API."""
 
     stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -662,9 +661,10 @@ class ChatAnthropic(BaseChatModel):
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
-            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
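As the removed comment noted, the Anthropic SDK already defaults to 2 retries, and an explicit None is a meaningful value to its client rather than "use the default", which is why the parameter is now omitted entirely when unset. Roughly, at the SDK level (a sketch assuming the anthropic package is installed; the key is a placeholder):

import anthropic

# max_retries omitted: the SDK applies its own documented default (2).
client_default = anthropic.Anthropic(api_key="sk-placeholder")

# max_retries passed explicitly: the user's value overrides that default.
client_custom = anthropic.Anthropic(api_key="sk-placeholder", max_retries=5)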
@@ -316,7 +316,7 @@ class ChatFireworks(BaseChatModel):
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -22,6 +22,7 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
+      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',
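The ChatFireworks snapshot gains an explicit 'temperature': 0.0 because langchain's serializer keeps only constructor kwargs that differ from the field default: an explicit 0.0 used to equal the old default and was dropped, but it differs from the new default of None and is now recorded. A hedged illustration (assumes langchain_fireworks is installed; the key is a dummy):

from langchain_core.load import dumpd
from langchain_fireworks import ChatFireworks

llm = ChatFireworks(
    model="accounts/fireworks/models/mixtral-8x7b-instruct",
    temperature=0.0,
    api_key="dummy",
)
# temperature now differs from the default (None), so it is serialized.
print(dumpd(llm)["kwargs"]["temperature"])  # 0.0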
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args — client params:
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
             Groq API key. If not passed in will be read from env var GROQ_API_KEY.
@@ -303,7 +303,7 @@ class ChatGroq(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -327,11 +327,11 @@ class ChatGroq(BaseChatModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
@@ -379,10 +379,11 @@ class ChatGroq(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
         if self.temperature == 0:
             self.temperature = 1e-8
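With the guards in place, an unset n sails through validation and Groq's server-side default applies; explicit bad values still fail fast. A quick check of the intended behavior (assumes langchain_groq is installed; the key is a dummy):

from langchain_groq import ChatGroq

# n unset (None): no error; the request will simply omit it.
ok = ChatGroq(model="mixtral-8x7b-32768", api_key="dummy")

try:
    ChatGroq(model="mixtral-8x7b-32768", api_key="dummy", n=0)
except ValueError as err:
    print(err)  # validation error: n must be at least 1.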
@@ -392,10 +393,11 @@ class ChatGroq(BaseChatModel):
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
 
         try:
             import groq
@@ -17,7 +17,6 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'mixtral-8x7b-32768',
-    'n': 1,
     'request_timeout': 60.0,
     'stop': list([
     ]),
@@ -95,9 +95,12 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
 
     errors = [httpx.RequestError, httpx.StreamError]
-    return create_base_retry_decorator(
+    kwargs: dict = dict(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
+    return create_base_retry_decorator(
+        **{k: v for k, v in kwargs.items() if v is not None}
+    )
 
 
 def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
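Collecting the arguments in a dict and filtering out None before the call lets create_base_retry_decorator fall back to its own default when max_retries was never set. The filter-then-unpack pattern in isolation (with a stand-in function, not the real decorator):

def retry_decorator_stub(error_types: list, max_retries: int = 3) -> str:
    # Stand-in for create_base_retry_decorator and its built-in default.
    return f"retry up to {max_retries}x on {error_types}"


kwargs: dict = dict(error_types=[TimeoutError], max_retries=None)
# Dropping None entries lets the callee's own default (3) apply.
print(retry_decorator_stub(**{k: v for k, v in kwargs.items() if v is not None}))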
@@ -380,13 +383,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: int = 5
-    timeout: int = 120
-    max_concurrent_requests: int = 64
+    max_retries: Optional[int] = None
+    timeout: Optional[int] = None
+    max_concurrent_requests: Optional[int] = None
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     max_tokens: Optional[int] = None
-    top_p: float = 1
+    top_p: Optional[float] = None
     """Decode using nucleus sampling: consider the smallest set of tokens whose
     probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None
@@ -9,7 +9,6 @@
     ]),
     'kwargs': dict({
       'endpoint': 'boo',
-      'max_concurrent_requests': 64,
       'max_retries': 2,
       'max_tokens': 100,
       'mistral_api_key': dict({
@@ -22,7 +21,6 @@
       'model': 'mistral-small',
       'temperature': 0.0,
       'timeout': 60,
-      'top_p': 1,
     }),
     'lc': 1,
     'name': 'ChatMistralAI',
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
         https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
     timeout: Union[float, Tuple[float, float], Any, None]
         Timeout for requests.
-    max_retries: int
+    max_retries: Optional[int]
         Max number of retries.
     organization: Optional[str]
         OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
         if self.disabled_params is None:
@@ -641,10 +641,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
     """Penalizes repeated tokens."""
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ class BaseChatOpenAI(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
         # Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,10 +551,12 @@ class BaseChatOpenAI(BaseChatModel):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
+
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client
@@ -609,14 +611,14 @@ class BaseChatOpenAI(BaseChatModel):
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
+            "n": self.n,
+            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }
 
         params = {
             "model": self.model_name,
             "stream": self.streaming,
-            "n": self.n,
-            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }
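Moving "n" and "temperature" into the exclude_if_none dict means they vanish from the request payload whenever they are unset, leaving the decision to the provider's server-side defaults. The effect, sketched with plain dicts:

# Both unset under the new defaults.
exclude_if_none = {"n": None, "temperature": None, "max_tokens": 100}

params = {
    "model": "gpt-4o-2024-08-06",
    "stream": False,
    **{k: v for k, v in exclude_if_none.items() if v is not None},
}
# n and temperature are omitted, which is exactly what the updated
# test__get_request_payload below now expects.
print(params)  # {'model': 'gpt-4o-2024-08-06', 'stream': False, 'max_tokens': 100}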
@@ -1565,7 +1567,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
 
     timeout: Union[float, Tuple[float, float], Any, None]
         Timeout for requests.
-    max_retries: int
+    max_retries: Optional[int]
         Max number of retries.
     api_key: Optional[str]
         OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
@@ -15,7 +15,6 @@
     }),
     'max_retries': 2,
     'max_tokens': 100,
-    'n': 1,
     'openai_api_key': dict({
       'id': list([
         'AZURE_OPENAI_API_KEY',
@@ -11,7 +11,6 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'gpt-3.5-turbo',
-    'n': 1,
     'openai_api_key': dict({
       'id': list([
         'OPENAI_API_KEY',
@@ -877,8 +877,6 @@ def test__get_request_payload() -> None:
         ],
         "model": "gpt-4o-2024-08-06",
         "stream": False,
-        "n": 1,
-        "temperature": 0.7,
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected
@@ -8,7 +8,13 @@ TEST_FILE ?= tests/unit_tests/
 
 integration_test integration_tests: TEST_FILE=tests/integration_tests/
 
-test tests integration_test integration_tests:
+test tests:
 	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
 
 test_watch:
 	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
 
+integration_test integration_tests:
+	poetry run pytest $(TEST_FILE)
+
 ######################
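Splitting the combined target keeps "make test" offline (pytest-socket blocks network access in unit tests) while "make integration_tests" runs without the socket restrictions it would otherwise inherit:

make test               # unit tests, sockets disabled
make integration_tests  # integration tests, network allowed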
@@ -320,9 +320,9 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        if self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
         client_params: dict = {
@@ -331,10 +331,11 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             ),
             "base_url": self.xai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
 
         if client_params["api_key"] is None:
             raise ValueError(
@@ -10,7 +10,6 @@
     'max_retries': 2,
     'max_tokens': 100,
     'model_name': 'grok-beta',
-    'n': 1,
     'request_timeout': 60.0,
     'stop': list([
     ]),