Mirror of https://github.com/hwchase17/langchain.git
openai[patch]: use max_completion_tokens in place of max_tokens (#26917)
`max_tokens` is deprecated in the OpenAI Chat Completions API in favor of `max_completion_tokens`: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens

Co-authored-by: Bagatur <baskaryan@gmail.com>
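For callers, the practical effect of the aliased field added below is that requests built by `ChatOpenAI` carry `max_completion_tokens` rather than the deprecated `max_tokens`. A minimal usage sketch, assuming a `langchain-openai` release that includes this change (model name and token limit are illustrative):

```python
from langchain_openai import ChatOpenAI

# Because the field is aliased, either keyword should populate the same setting;
# the request body sent to OpenAI uses "max_completion_tokens".
llm = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=256)
# llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=256)  # field name still accepted
```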
@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(default=None)
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
@@ -699,6 +699,7 @@ class BaseChatOpenAI(BaseChatModel):
         messages = self._convert_input(input_).to_messages()
         if stop is not None:
             kwargs["stop"] = stop
+
         return {
             "messages": [_convert_message_to_dict(m) for m in messages],
             **self._default_params,
@@ -853,7 +854,9 @@ class BaseChatOpenAI(BaseChatModel):
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens) or params.get(
+            "max_completion_tokens", self.max_tokens
+        ):
             ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
@@ -1501,7 +1504,7 @@ class BaseChatOpenAI(BaseChatModel):
         return filtered


-class ChatOpenAI(BaseChatOpenAI):
+class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.

     .. dropdown:: Setup
@@ -1963,6 +1966,9 @@ class ChatOpenAI(BaseChatOpenAI):
         message chunks will be generated during the stream including usage metadata.
     """

+    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
+    """Maximum number of tokens to generate."""
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"openai_api_key": "OPENAI_API_KEY"}
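The `alias="max_completion_tokens"` above keeps the Python attribute name `max_tokens` (so existing code keeps working) while exposing the new OpenAI parameter name. A standalone Pydantic sketch of the same pattern, with illustrative names that are not part of this diff:

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Demo(BaseModel):
    # populate_by_name lets callers use either the alias or the field name.
    model_config = ConfigDict(populate_by_name=True)
    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")


print(Demo(max_completion_tokens=64).max_tokens)  # 64
print(Demo(max_tokens=64).max_tokens)  # 64, accepted via populate_by_name
```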
@@ -1992,6 +1998,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True

+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        params = super()._default_params
+        if "max_tokens" in params:
+            params["max_completion_tokens"] = params.pop("max_tokens")
+
+        return params
+
+    def _get_request_payload(
+        self,
+        input_: LanguageModelInput,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+        # max_tokens was deprecated in favor of max_completion_tokens
+        # in September 2024 release
+        if "max_tokens" in payload:
+            payload["max_completion_tokens"] = payload.pop("max_tokens")
+        return payload
+
     def _should_stream_usage(
         self, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> bool:
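One hedged way to see the effect of the two overrides above is to inspect the payload directly via `_get_request_payload` (a private helper, so this is only an illustration; the placeholder API key is never used for a network call here):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", api_key="sk-placeholder", max_tokens=32)
payload = llm._get_request_payload("hello")
# Expected with this change: the deprecated key is renamed before the request is built.
print("max_completion_tokens" in payload)  # True
print("max_tokens" in payload)  # False
```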