diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index fd64b824a8d..a297f12c6c4 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -307,7 +307,7 @@ class ChatAnthropic(BaseChatModel):
     Key init args — client params:
         timeout: Optional[float]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries if a request fails.
         api_key: Optional[str]
             Anthropic API key. If not passed in will be read from env var ANTHROPIC_API_KEY.
@@ -558,8 +558,7 @@ class ChatAnthropic(BaseChatModel):
     default_request_timeout: Optional[float] = Field(None, alias="timeout")
     """Timeout for requests to Anthropic Completion API."""

-    # sdk default = 2: https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#retries
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Number of retries allowed for requests sent to the Anthropic Completion API."""

     stop_sequences: Optional[List[str]] = Field(None, alias="stop")
@@ -662,9 +661,10 @@ class ChatAnthropic(BaseChatModel):
         client_params: Dict[str, Any] = {
             "api_key": self.anthropic_api_key.get_secret_value(),
             "base_url": self.anthropic_api_url,
-            "max_retries": self.max_retries,
             "default_headers": (self.default_headers or None),
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         # value <= 0 indicates the param should be ignored. None is a meaningful value
         # for Anthropic client and treated differently than not specifying the param at
         # all.
diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 92977104706..02ba31b47e7 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -316,7 +316,7 @@ class ChatFireworks(BaseChatModel):
     model_name: str = Field(
         default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
     )
     """Model name to use."""
-    temperature: float = 0.0
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[str, List[str]]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
diff --git a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
index 4375bf55ff0..da33d819cd3 100644
--- a/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/fireworks/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -22,6 +22,7 @@
       'request_timeout': 60.0,
       'stop': list([
       ]),
+      'temperature': 0.0,
     }),
     'lc': 1,
     'name': 'ChatFireworks',
diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 5868e9cc6a3..838867dc37b 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -119,7 +119,7 @@ class ChatGroq(BaseChatModel):
     Key init args — client params:
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
             Groq API key. If not passed in will be read from env var GROQ_API_KEY.
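
In practice, the Optional[int] max_retries change above means an unset value
is dropped from client_params, so the provider SDK falls back to its own
default (for Anthropic, the 2 retries the removed comment documented), while
an explicit int still pins it. The same pattern repeats for each provider
below. A minimal sketch, assuming ANTHROPIC_API_KEY is set; the model name
is illustrative only:

    from langchain_anthropic import ChatAnthropic

    # No max_retries given: the key is omitted from client_params and the
    # Anthropic SDK's own default (2 retries) applies.
    llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

    # Explicitly set: forwarded to the client as before.
    pinned = ChatAnthropic(model="claude-3-5-sonnet-20240620", max_retries=5)
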
@@ -303,7 +303,7 @@ class ChatGroq(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="mixtral-8x7b-32768", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
     """Default stop sequences."""
@@ -327,11 +327,11 @@ class ChatGroq(BaseChatModel):
     )
     """Timeout for requests to Groq completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
@@ -379,10 +379,11 @@ class ChatGroq(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
+
         if self.temperature == 0:
             self.temperature = 1e-8

@@ -392,10 +393,11 @@ class ChatGroq(BaseChatModel):
             ),
             "base_url": self.groq_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries

         try:
             import groq
diff --git a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
index 741d2c84745..919d2a5c3d3 100644
--- a/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/groq/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -17,7 +17,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'mixtral-8x7b-32768',
-      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 686e4a7e6a8..63edab1f29a 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -95,9 +95,12 @@ def _create_retry_decorator(
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""

     errors = [httpx.RequestError, httpx.StreamError]
-    return create_base_retry_decorator(
+    kwargs: dict = dict(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
+    return create_base_retry_decorator(
+        **{k: v for k, v in kwargs.items() if v is not None}
+    )


 def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
@@ -380,13 +383,13 @@ class ChatMistralAI(BaseChatModel):
         default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
     )
     endpoint: Optional[str] = Field(default=None, alias="base_url")
-    max_retries: int = 5
-    timeout: int = 120
-    max_concurrent_requests: int = 64
+    max_retries: Optional[int] = None
+    timeout: Optional[int] = None
+    max_concurrent_requests: Optional[int] = None
     model: str = Field(default="mistral-small", alias="model_name")
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     max_tokens: Optional[int] = None
-    top_p: float = 1
+    top_p: Optional[float] = None
     """Decode using nucleus sampling: consider the smallest set of tokens whose
     probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
     random_seed: Optional[int] = None
diff --git a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
index f7986097c47..07e4f33f3ce 100644
--- a/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
+++ b/libs/partners/mistralai/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -9,7 +9,6 @@
     ]),
     'kwargs': dict({
       'endpoint': 'boo',
-      'max_concurrent_requests': 64,
       'max_retries': 2,
       'max_tokens': 100,
       'mistral_api_key': dict({
@@ -22,7 +21,6 @@
       'model': 'mistral-small',
       'temperature': 0.0,
       'timeout': 60,
-      'top_p': 1,
     }),
     'lc': 1,
     'name': 'ChatMistralAI',
diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py
index 2e1e5f8abfe..c2de17988cb 100644
--- a/libs/partners/openai/langchain_openai/chat_models/azure.py
+++ b/libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
             https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         organization: Optional[str]
             OpenAI organization ID. If not passed in will be read from env
@@ -586,9 +586,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         if self.disabled_params is None:
@@ -641,10 +641,11 @@ class AzureChatOpenAI(BaseChatOpenAI):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
         if not self.client:
             sync_specific = {"http_client": self.http_client}
             self.root_client = openai.AzureOpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 142e7eca1a8..546a33c720e 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
-    temperature: float = 0.7
+    temperature: Optional[float] = None
     """What sampling temperature to use."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
@@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
     None."""
-    max_retries: int = 2
+    max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
     """Penalizes repeated tokens."""
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
+    n: Optional[int] = None
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
@@ -532,9 +532,9 @@ class BaseChatOpenAI(BaseChatModel):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        elif self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         # Check OPENAI_ORGANIZATION for backwards compatibility.
@@ -551,10 +551,12 @@ class BaseChatOpenAI(BaseChatModel):
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries
+
         if self.openai_proxy and (self.http_client or self.http_async_client):
             openai_proxy = self.openai_proxy
             http_client = self.http_client
@@ -609,14 +611,14 @@ class BaseChatOpenAI(BaseChatModel):
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
+            "n": self.n,
+            "temperature": self.temperature,
             "reasoning_effort": self.reasoning_effort,
         }
         params = {
             "model": self.model_name,
             "stream": self.streaming,
-            "n": self.n,
-            "temperature": self.temperature,
             **{k: v for k, v in exclude_if_none.items() if v is not None},
             **self.model_kwargs,
         }
@@ -1565,7 +1567,7 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
         timeout: Union[float, Tuple[float, float], Any, None]
             Timeout for requests.
-        max_retries: int
+        max_retries: Optional[int]
             Max number of retries.
         api_key: Optional[str]
             OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
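
Since `n` and `temperature` now sit in the exclude-if-None dict, a default
ChatOpenAI no longer sends either field, so the API's server-side defaults
apply; the updated unit test below confirms this. A minimal sketch of the
effect (the dummy api_key is illustrative, and _get_request_payload is
private API used here only for inspection):

    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-2024-08-06", api_key="test")
    payload = llm._get_request_payload([HumanMessage("hello")])
    # Neither key is sent unless explicitly configured.
    assert "n" not in payload and "temperature" not in payload
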
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
index 2b8c3563b94..2060512958a 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_azure_standard.ambr
@@ -15,7 +15,6 @@
       }),
       'max_retries': 2,
       'max_tokens': 100,
-      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'AZURE_OPENAI_API_KEY',
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
index b7ab1ce9c07..e7307c6158f 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
+++ b/libs/partners/openai/tests/unit_tests/chat_models/__snapshots__/test_base_standard.ambr
@@ -11,7 +11,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'gpt-3.5-turbo',
-      'n': 1,
       'openai_api_key': dict({
         'id': list([
           'OPENAI_API_KEY',
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 2e6cca0cd2d..5eac32c0447 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -877,8 +877,6 @@ def test__get_request_payload() -> None:
         ],
         "model": "gpt-4o-2024-08-06",
         "stream": False,
-        "n": 1,
-        "temperature": 0.7,
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected
diff --git a/libs/partners/xai/Makefile b/libs/partners/xai/Makefile
index 1626a01bc49..6859cc789a1 100644
--- a/libs/partners/xai/Makefile
+++ b/libs/partners/xai/Makefile
@@ -8,7 +8,13 @@ TEST_FILE ?= tests/unit_tests/

 integration_test integration_tests: TEST_FILE=tests/integration_tests/

-test tests integration_test integration_tests:
+test tests:
+	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
+
+test_watch:
+	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
+
+integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)

 ######################
diff --git a/libs/partners/xai/langchain_xai/chat_models.py b/libs/partners/xai/langchain_xai/chat_models.py
index 775d22740cd..a854be5487d 100644
--- a/libs/partners/xai/langchain_xai/chat_models.py
+++ b/libs/partners/xai/langchain_xai/chat_models.py
@@ -320,9 +320,9 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if self.n < 1:
+        if self.n is not None and self.n < 1:
             raise ValueError("n must be at least 1.")
-        if self.n > 1 and self.streaming:
+        if self.n is not None and self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")

         client_params: dict = {
@@ -331,10 +331,11 @@ class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
             ),
             "base_url": self.xai_api_base,
             "timeout": self.request_timeout,
-            "max_retries": self.max_retries,
             "default_headers": self.default_headers,
             "default_query": self.default_query,
         }
+        if self.max_retries is not None:
+            client_params["max_retries"] = self.max_retries

         if client_params["api_key"] is None:
             raise ValueError(
diff --git a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
index 5c6f113f217..4cd1261555c 100644
--- a/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
+++ b/libs/partners/xai/tests/unit_tests/__snapshots__/test_chat_models_standard.ambr
@@ -10,7 +10,6 @@
       'max_retries': 2,
       'max_tokens': 100,
       'model_name': 'grok-beta',
-      'n': 1,
       'request_timeout': 60.0,
       'stop': list([
       ]),
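
Every provider touched above converges on the same wiring: an optional
client param is forwarded only when the user actually set it, leaving the
default to the underlying SDK. A distilled sketch of that pattern (the
helper name and signature are hypothetical, not part of any package):

    from typing import Any, Dict, Optional

    def build_client_params(
        api_key: str, max_retries: Optional[int] = None
    ) -> Dict[str, Any]:
        # Hypothetical helper: include max_retries only when set, so the
        # provider SDK's own retry default wins otherwise.
        params: Dict[str, Any] = {"api_key": api_key}
        if max_retries is not None:
            params["max_retries"] = max_retries
        return params
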