Standardized openai init params (#21739)

## Patch Summary
community:openai[patch]: standardize init args

## Details
I made changes to the OpenAI Chat API wrapper test in the Langchain
open-source repository

- **File**: `libs/community/tests/unit_tests/chat_models/test_openai.py`
- **Changes**:
  - Updated `max_retries` with Pydantic Field
  - Updated the corresponding unit test
- **Related Issues**: #20085
  - Updated max_retries with Pydantic Field, updated the unit test.

---------

Co-authored-by: JuHyung Son <sonju0427@gmail.com>
This commit is contained in:
Kyle Cassidy 2024-05-16 12:30:52 -04:00 committed by GitHub
parent c03fd93fc1
commit eca8c4bcc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 6 deletions

View File

@ -1,4 +1,5 @@
"""OpenAI chat wrapper.""" """OpenAI chat wrapper."""
from __future__ import annotations from __future__ import annotations
import logging import logging
@ -217,7 +218,7 @@ class ChatOpenAI(BaseChatModel):
) )
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None.""" None."""
max_retries: int = 2 max_retries: int = Field(default=2)
"""Maximum number of retries to make when generating.""" """Maximum number of retries to make when generating."""
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""

View File

@ -1,6 +1,7 @@
"""Test OpenAI Chat API wrapper.""" """Test OpenAI Chat API wrapper."""
import json import json
from typing import Any from typing import Any, List
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -17,10 +18,19 @@ from langchain_community.chat_models.openai import ChatOpenAI
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_openai_model_param() -> None: def test_openai_model_param() -> None:
llm = ChatOpenAI(model="foo", openai_api_key="foo") # type: ignore[call-arg] test_cases: List[dict] = [
assert llm.model_name == "foo" {"model_name": "foo", "openai_api_key": "foo"},
llm = ChatOpenAI(model_name="foo", openai_api_key="foo") # type: ignore[call-arg] {"model": "foo", "openai_api_key": "foo"},
assert llm.model_name == "foo" {"model_name": "foo", "api_key": "foo"},
{"model_name": "foo", "openai_api_key": "foo", "max_retries": 2},
]
for case in test_cases:
llm = ChatOpenAI(**case)
assert llm.model_name == "foo", "Model name should be 'foo'"
assert llm.openai_api_key == "foo", "API key should be 'foo'"
assert hasattr(llm, "max_retries"), "max_retries attribute should exist"
assert llm.max_retries == 2, "max_retries default should be set to 2"
def test_function_message_dict_to_function_message() -> None: def test_function_message_dict_to_function_message() -> None: