mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 09:40:26 +00:00
Standardized openai init params (#21739)
## Patch Summary community:openai[patch]: standardize init args ## Details I made changes to the OpenAI Chat API wrapper test in the Langchain open-source repository - **File**: `libs/community/tests/unit_tests/chat_models/test_openai.py` - **Changes**: - Updated `max_retries` with Pydantic Field - Updated the corresponding unit test - **Related Issues**: #20085 - Updated max_retries with Pydantic Field, updated the unit test. --------- Co-authored-by: JuHyung Son <sonju0427@gmail.com>
This commit is contained in:
parent
c03fd93fc1
commit
eca8c4bcc6
@ -1,4 +1,5 @@
|
||||
"""OpenAI chat wrapper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@ -217,7 +218,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
max_retries: int = 2
|
||||
max_retries: int = Field(default=2)
|
||||
"""Maximum number of retries to make when generating."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test OpenAI Chat API wrapper."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -17,10 +18,19 @@ from langchain_community.chat_models.openai import ChatOpenAI
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_openai_model_param() -> None:
|
||||
llm = ChatOpenAI(model="foo", openai_api_key="foo") # type: ignore[call-arg]
|
||||
assert llm.model_name == "foo"
|
||||
llm = ChatOpenAI(model_name="foo", openai_api_key="foo") # type: ignore[call-arg]
|
||||
assert llm.model_name == "foo"
|
||||
test_cases: List[dict] = [
|
||||
{"model_name": "foo", "openai_api_key": "foo"},
|
||||
{"model": "foo", "openai_api_key": "foo"},
|
||||
{"model_name": "foo", "api_key": "foo"},
|
||||
{"model_name": "foo", "openai_api_key": "foo", "max_retries": 2},
|
||||
]
|
||||
|
||||
for case in test_cases:
|
||||
llm = ChatOpenAI(**case)
|
||||
assert llm.model_name == "foo", "Model name should be 'foo'"
|
||||
assert llm.openai_api_key == "foo", "API key should be 'foo'"
|
||||
assert hasattr(llm, "max_retries"), "max_retries attribute should exist"
|
||||
assert llm.max_retries == 2, "max_retries default should be set to 2"
|
||||
|
||||
|
||||
def test_function_message_dict_to_function_message() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user