Revert "integrations[patch]: remove non-required chat param defaults" (#29048)
Reverts langchain-ai/langchain#26730 while we discuss the best way to release the default-value changes (especially the OpenAI temperature default).
This commit is contained in:
parent 3d7ae8b5d2
commit 187131c55c
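
The net effect of the revert, sketched below before the diff itself (illustrative only, not part of the change): models instantiated without explicit params once again carry the concrete defaults shown in the hunks that follow, rather than None. The sketch assumes langchain-openai and langchain-groq are installed and the relevant API keys are set in the environment.

# Illustrative sketch of the restored defaults (assumes OPENAI_API_KEY and
# GROQ_API_KEY are set; the asserted values are the field defaults in this diff).
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")
assert llm.temperature == 0.7  # was Optional[float] = None after #26730
assert llm.n == 1
assert llm.max_retries == 2

groq = ChatGroq(model="mixtral-8x7b-32768")
assert groq.temperature == 0.7
assert groq.max_retries == 2
assert groq.n == 1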
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args — client params:
         timeout: Optional[float]
            Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries if a request fails.
        api_key: Optional[str]
            Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
@@ -558,7 +558,8 @@ class ChatAnthropic(BaseChatModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""

-    max_retries: Optional[int] = None
+    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
+    max_retries: int = 2
     """Number of retries allowed for requests sent to the Anthropic Completion API."""

     stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -661,10 +662,9 @@ class ChatAnthropic(BaseChatModel):
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
+            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
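A usage sketch of the restored Anthropic behavior (illustrative; assumes langchain-anthropic is installed and ANTHROPIC_API_KEY is set): max_retries now defaults to 2, mirroring the anthropic SDK default referenced in the comment above, and is always forwarded to the client rather than added conditionally.

# Illustrative only. With this revert, max_retries is a plain int defaulting
# to 2 (the anthropic SDK default) and is passed to the client unconditionally.
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
assert llm.max_retries == 2

# An explicit value still flows straight into client_params as before.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest", max_retries=4)
assert llm.max_retries == 4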
@@ -316,7 +316,7 @@ class ChatFireworks(BaseChatModel):
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.0
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -22,7 +22,6 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
-      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args — client params:
        timeout: Union[float, Tuple[float, float], Any, None]
            Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
        api_key: Optional[str]
            Groq API key. If not passed in will be read from env var GROQ_API_KEY.
@@ -303,7 +303,7 @@ class ChatGroq(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -327,11 +327,11 @@ class ChatGroq(BaseChatModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: Optional[int] = None
+    max_retries: int = 2
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: Optional[int] = None
+    n: int = 1
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
@@ -379,11 +379,10 @@ class ChatGroq(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         if self.temperature == 0:
             self.temperature = 1e-8

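With n restored to a plain int, the validator above no longer needs None guards. A rough usage sketch (illustrative; assumes langchain-groq is installed and GROQ_API_KEY is set):

# Illustrative only. n defaults to 1 again and the n/streaming constraint
# still holds; pydantic v2 surfaces the ValueError as a ValidationError.
from langchain_groq import ChatGroq

llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
assert llm.temperature == 1e-8  # 0 is remapped by the validator above

try:
    ChatGroq(model="mixtral-8x7b-32768", n=2, streaming=True)
except Exception as err:  # ValidationError wrapping "n must be 1 when streaming."
    print(err)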
@@ -393,11 +392,10 @@ class ChatGroq(BaseChatModel):
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries

         try:
             import groq
@@ -17,6 +17,7 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'mixtral-8x7b-32768',
+      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
@@ -95,11 +95,8 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""

     errors = [httpx.RequestError, httpx.StreamError]
-    kwargs: dict = dict(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
     return create_base_retry_decorator(
-        **{k: v for k, v in kwargs.items() if v is not None}
+        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )

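The simplification above works because llm.max_retries can no longer be None. A rough sketch of the restored call pattern (create_base_retry_decorator is the langchain-core helper the diff already uses; the decorated function is hypothetical):

# Sketch of the restored pattern: kwargs are passed directly instead of being
# filtered for None first, since max_retries is guaranteed to be an int.
import httpx
from langchain_core.language_models.llms import create_base_retry_decorator

retry = create_base_retry_decorator(
    error_types=[httpx.RequestError, httpx.StreamError],
    max_retries=5,  # the ChatMistralAI default restored in the next hunk
)

@retry
def flaky_request() -> None:  # hypothetical function, for illustration only
    raise httpx.RequestError("transient failure")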
@@ -383,13 +380,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: Optional[int] = None
-    timeout: Optional[int] = None
-    max_concurrent_requests: Optional[int] = None
+    max_retries: int = 5
+    timeout: int = 120
+    max_concurrent_requests: int = 64
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     max_tokens: Optional[int] = None
-    top_p: Optional[float] = None
+    top_p: float = 1
     """Decode using nucleus sampling: consider the smallest set of tokens whose
     probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None
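A quick sketch of the ChatMistralAI defaults this hunk restores (illustrative; assumes langchain-mistralai is installed and MISTRAL_API_KEY is set):

# Illustrative only; the asserted values are the field defaults restored above.
from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI()  # model defaults to "mistral-small"
assert llm.max_retries == 5
assert llm.timeout == 120
assert llm.max_concurrent_requests == 64
assert llm.temperature == 0.7
assert llm.top_p == 1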
@@ -9,6 +9,7 @@
     ]),
     'kwargs': dict({
       'endpoint': 'boo',
+      'max_concurrent_requests': 64,
       'max_retries': 2,
       'max_tokens': 100,
       'mistral_api_key': dict({
@@ -21,6 +22,7 @@
       'model': 'mistral-small',
       'temperature': 0.0,
       'timeout': 60,
+      'top_p': 1,
     }),
     'lc': 1,
     'name': 'ChatMistralAI',
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
        timeout: Union[float, Tuple[float, float], Any, None]
            Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
        organization: Optional[str]
            OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         if self.disabled_params is None:
@@ -641,11 +641,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: Optional[float] = None
+    temperature: float = 0.7
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: Optional[int] = None
+    max_retries: int = 2
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
     """Penalizes repeated tokens."""
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: Optional[int] = None
+    n: int = 1
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ class BaseChatOpenAI(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        elif self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         # Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,12 +551,10 @@ class BaseChatOpenAI(BaseChatModel):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries

         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client
@@ -611,14 +609,14 @@ class BaseChatOpenAI(BaseChatModel):
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
-            "n": self.n,
-            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }

         params = {
             "model": self.model_name,
             "stream": self.streaming,
+            "n": self.n,
+            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }
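Because "n" and "temperature" moved out of the exclude-if-None dict into params, they now always appear in the request payload, which is what the updated test snapshot further down asserts. A sketch (illustrative; _get_request_payload is a private helper, used here only to mirror that test; assumes OPENAI_API_KEY is set):

# Illustrative only: the payload now always carries "n" and "temperature".
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-2024-08-06")
payload = llm._get_request_payload([("human", "hello")])
assert payload["n"] == 1
assert payload["temperature"] == 0.7
assert payload["stream"] is False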
@@ -1567,7 +1565,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]

        timeout: Union[float, Tuple[float, float], Any, None]
            Timeout for requests.
-        max_retries: Optional[int]
+        max_retries: int
            Max number of retries.
        api_key: Optional[str]
            OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
@@ -15,6 +15,7 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
+      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'AZURE_OPENAI_API_KEY',
@@ -11,6 +11,7 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'gpt-3.5-turbo',
+      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'OPENAI_API_KEY',
@@ -877,6 +877,8 @@ def test__get_request_payload() -> None:
         ],
         "model": "gpt-4o-2024-08-06",
         "stream": False,
+        "n": 1,
+        "temperature": 0.7,
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected
@@ -8,13 +8,7 @@ TEST_FILE ?= tests/unit_tests/

 integration_test integration_tests: TEST_FILE=tests/integration_tests/

-test tests:
-	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
-
-test_watch:
-	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
-
-integration_test integration_tests:
+test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)

 ######################
@@ -320,9 +320,9 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n is not None and self.n < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n is not None and self.n > 1 and self.streaming:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         client_params: dict = {
@@ -331,11 +331,10 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             ),
             "base_url": self.xai_api_base,
             "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
-        if self.max_retries is not None:
-            client_params["max_retries"] = self.max_retries

         if client_params["api_key"] is None:
             raise ValueError(
|
@ -10,6 +10,7 @@
|
|||||||
'max_retries': 2,
|
'max_retries': 2,
|
||||||
'max_tokens': 100,
|
'max_tokens': 100,
|
||||||
'model_name': 'grok-beta',
|
'model_name': 'grok-beta',
|
||||||
|
'n': 1,
|
||||||
'request_timeout': 60.0,
|
'request_timeout': 60.0,
|
||||||
'stop': list([
|
'stop': list([
|
||||||
]),
|
]),
|
||||||
|