openai[patch]: use max_completion_tokens in place of max_tokens (#26917)

`max_tokens` is deprecated:
https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
ccurme
2024-11-26 11:30:19 -05:00
committed by GitHub
parent 869c8f5879
commit 42b18824c2
3 changed files with 75 additions and 29 deletions

View File

@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
max_tokens: Optional[int] = None
max_tokens: Optional[int] = Field(default=None)
"""Maximum number of tokens to generate."""
tiktoken_model_name: Optional[str] = None
"""The model name to pass to tiktoken when using this class.
@@ -699,6 +699,7 @@ class BaseChatOpenAI(BaseChatModel):
messages = self._convert_input(input_).to_messages()
if stop is not None:
kwargs["stop"] = stop
return {
"messages": [_convert_message_to_dict(m) for m in messages],
**self._default_params,
@@ -853,7 +854,9 @@ class BaseChatOpenAI(BaseChatModel):
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
if ls_max_tokens := params.get("max_tokens", self.max_tokens) or params.get(
"max_completion_tokens", self.max_tokens
):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
@@ -1501,7 +1504,7 @@ class BaseChatOpenAI(BaseChatModel):
return filtered
class ChatOpenAI(BaseChatOpenAI):
class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""OpenAI chat model integration.
.. dropdown:: Setup
@@ -1963,6 +1966,9 @@ class ChatOpenAI(BaseChatOpenAI):
message chunks will be generated during the stream including usage metadata.
"""
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
"""Maximum number of tokens to generate."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@@ -1992,6 +1998,29 @@ class ChatOpenAI(BaseChatOpenAI):
"""Return whether this model can be serialized by Langchain."""
return True
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
params = super()._default_params
if "max_tokens" in params:
params["max_completion_tokens"] = params.pop("max_tokens")
return params
def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> dict:
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
# max_tokens was deprecated in favor of max_completion_tokens
# in September 2024 release
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload
def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
) -> bool: