openai[patch]: use max_completion_tokens in place of max_tokens (#26917)
`max_tokens` is deprecated: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 869c8f5879
commit 42b18824c2
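For context, a short usage sketch of the new parameter, based on the tests added in this PR (the model names here are only examples):

```python
from langchain_openai import ChatOpenAI

# New spelling: `max_completion_tokens` is accepted as a constructor alias
# and populates the existing `max_tokens` field.
llm = ChatOpenAI(model="gpt-4o", max_completion_tokens=10)
assert llm.max_tokens == 10

# Old spelling still works; it is remapped to `max_completion_tokens` in the
# request payload before the call to the Chat Completions API.
legacy = ChatOpenAI(model="gpt-4o", max_tokens=10)  # type: ignore[call-arg]
assert legacy.max_tokens == 10
```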
@@ -435,7 +435,7 @@ class BaseChatOpenAI(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     top_p: Optional[float] = None
     """Total probability mass of tokens to consider at each step."""
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(default=None)
     """Maximum number of tokens to generate."""
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
@@ -699,6 +699,7 @@ class BaseChatOpenAI(BaseChatModel):
         messages = self._convert_input(input_).to_messages()
         if stop is not None:
             kwargs["stop"] = stop
+
         return {
             "messages": [_convert_message_to_dict(m) for m in messages],
             **self._default_params,
@@ -853,7 +854,9 @@ class BaseChatOpenAI(BaseChatModel):
             ls_model_type="chat",
             ls_temperature=params.get("temperature", self.temperature),
         )
-        if ls_max_tokens := params.get("max_tokens", self.max_tokens):
+        if ls_max_tokens := params.get("max_tokens", self.max_tokens) or params.get(
+            "max_completion_tokens", self.max_tokens
+        ):
            ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
            ls_params["ls_stop"] = ls_stop
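A self-contained sketch of how the `ls_max_tokens` fallback above resolves; the `params` dict and the default value are hypothetical stand-ins:

```python
# `:=` binds more loosely than `or`, so ls_max_tokens is assigned the value of
# the whole `or` chain: "max_tokens" if present, otherwise "max_completion_tokens".
params = {"max_completion_tokens": 128}  # hypothetical request params
default_max_tokens = None  # stand-in for self.max_tokens

if ls_max_tokens := params.get("max_tokens", default_max_tokens) or params.get(
    "max_completion_tokens", default_max_tokens
):
    print(ls_max_tokens)  # prints 128
```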
@@ -1501,7 +1504,7 @@ class BaseChatOpenAI(BaseChatModel):
         return filtered


-class ChatOpenAI(BaseChatOpenAI):
+class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.

     .. dropdown:: Setup
@@ -1963,6 +1966,9 @@ class ChatOpenAI(BaseChatOpenAI):
         message chunks will be generated during the stream including usage metadata.
     """

+    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
+    """Maximum number of tokens to generate."""
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"openai_api_key": "OPENAI_API_KEY"}
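A minimal Pydantic sketch (a standalone `Demo` model, not the real class) of how the `max_completion_tokens` alias added above keeps both constructor spellings working; it assumes population by field name is enabled, as it is for LangChain's model classes:

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class Demo(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")


assert Demo(max_completion_tokens=10).max_tokens == 10  # new alias
assert Demo(max_tokens=10).max_tokens == 10  # old field name still accepted
```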
@@ -1992,6 +1998,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True

+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        params = super()._default_params
+        if "max_tokens" in params:
+            params["max_completion_tokens"] = params.pop("max_tokens")
+
+        return params
+
+    def _get_request_payload(
+        self,
+        input_: LanguageModelInput,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> dict:
+        payload = super()._get_request_payload(input_, stop=stop, **kwargs)
+        # max_tokens was deprecated in favor of max_completion_tokens
+        # in September 2024 release
+        if "max_tokens" in payload:
+            payload["max_completion_tokens"] = payload.pop("max_tokens")
+        return payload
+
     def _should_stream_usage(
         self, stream_usage: Optional[bool] = None, **kwargs: Any
     ) -> bool:
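To make the two overrides above concrete, here is the renaming step factored into a standalone helper; the helper name and sample payload are illustrative only:

```python
def remap_max_tokens(payload: dict) -> dict:
    """Rename the deprecated `max_tokens` key before the payload is sent."""
    if "max_tokens" in payload:
        payload["max_completion_tokens"] = payload.pop("max_tokens")
    return payload


assert remap_max_tokens({"model": "gpt-4o", "max_tokens": 10}) == {
    "model": "gpt-4o",
    "max_completion_tokens": 10,
}
```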
@@ -44,7 +44,7 @@ def test_chat_openai() -> None:
         max_retries=3,
         http_client=None,
         n=1,
-        max_tokens=10,
+        max_completion_tokens=10,
         default_headers=None,
         default_query=None,
     )
@@ -64,7 +64,7 @@ def test_chat_openai_model() -> None:

 def test_chat_openai_system_message() -> None:
     """Test ChatOpenAI wrapper with system message."""
-    chat = ChatOpenAI(max_tokens=10)
+    chat = ChatOpenAI(max_completion_tokens=10)
     system_message = SystemMessage(content="You are to chat with the user.")
     human_message = HumanMessage(content="Hello")
     response = chat.invoke([system_message, human_message])
@@ -75,7 +75,7 @@ def test_chat_openai_system_message() -> None:
 @pytest.mark.scheduled
 def test_chat_openai_generate() -> None:
     """Test ChatOpenAI wrapper with generate."""
-    chat = ChatOpenAI(max_tokens=10, n=2)
+    chat = ChatOpenAI(max_completion_tokens=10, n=2)
     message = HumanMessage(content="Hello")
     response = chat.generate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -92,7 +92,7 @@ def test_chat_openai_generate() -> None:
 @pytest.mark.scheduled
 def test_chat_openai_multiple_completions() -> None:
     """Test ChatOpenAI wrapper with multiple completions."""
-    chat = ChatOpenAI(max_tokens=10, n=5)
+    chat = ChatOpenAI(max_completion_tokens=10, n=5)
     message = HumanMessage(content="Hello")
     response = chat._generate([message])
     assert isinstance(response, ChatResult)
@@ -108,7 +108,7 @@ def test_chat_openai_streaming() -> None:
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
     chat = ChatOpenAI(
-        max_tokens=10,
+        max_completion_tokens=10,
         streaming=True,
         temperature=0,
         callback_manager=callback_manager,
@@ -133,7 +133,9 @@ def test_chat_openai_streaming_generation_info() -> None:

     callback = _FakeCallback()
     callback_manager = CallbackManager([callback])
-    chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager)
+    chat = ChatOpenAI(
+        max_completion_tokens=2, temperature=0, callback_manager=callback_manager
+    )
     list(chat.stream("hi"))
     generation = callback.saved_things["generation"]
     # `Hello!` is two tokens, assert that that is what is returned
@@ -142,7 +144,7 @@ def test_chat_openai_streaming_generation_info() -> None:

 def test_chat_openai_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
-    chat = ChatOpenAI(max_tokens=10)
+    chat = ChatOpenAI(max_completion_tokens=10)
     message = HumanMessage(content="Hello")
     llm_result = chat.generate([[message]])
     assert llm_result.llm_output is not None
@@ -151,7 +153,7 @@ def test_chat_openai_llm_output_contains_model_name() -> None:

 def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
-    chat = ChatOpenAI(max_tokens=10, streaming=True)
+    chat = ChatOpenAI(max_completion_tokens=10, streaming=True)
     message = HumanMessage(content="Hello")
     llm_result = chat.generate([[message]])
     assert llm_result.llm_output is not None
@@ -161,13 +163,13 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
 def test_chat_openai_invalid_streaming_params() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
     with pytest.raises(ValueError):
-        ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5)
+        ChatOpenAI(max_completion_tokens=10, streaming=True, temperature=0, n=5)


 @pytest.mark.scheduled
 async def test_async_chat_openai() -> None:
     """Test async generation."""
-    chat = ChatOpenAI(max_tokens=10, n=2)
+    chat = ChatOpenAI(max_completion_tokens=10, n=2)
     message = HumanMessage(content="Hello")
     response = await chat.agenerate([[message], [message]])
     assert isinstance(response, LLMResult)
@@ -187,7 +189,7 @@ async def test_async_chat_openai_streaming() -> None:
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
     chat = ChatOpenAI(
-        max_tokens=10,
+        max_completion_tokens=10,
         streaming=True,
         temperature=0,
         callback_manager=callback_manager,
@@ -219,7 +221,7 @@ async def test_async_chat_openai_bind_functions() -> None:
         default=None, title="Fav Food", description="The person's favorite food"
     )

-    chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions(
+    chat = ChatOpenAI(max_completion_tokens=30, n=1, streaming=True).bind_functions(
         functions=[Person], function_call="Person"
     )

@@ -241,7 +243,7 @@ async def test_async_chat_openai_bind_functions() -> None:
 @pytest.mark.scheduled
 def test_openai_streaming() -> None:
     """Test streaming tokens from OpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
@@ -250,7 +252,7 @@ def test_openai_streaming() -> None:
 @pytest.mark.scheduled
 async def test_openai_astream() -> None:
     """Test streaming tokens from OpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     async for token in llm.astream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
@@ -259,7 +261,7 @@ async def test_openai_astream() -> None:
 @pytest.mark.scheduled
 async def test_openai_abatch() -> None:
     """Test streaming tokens from ChatOpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
     for token in result:
@@ -269,7 +271,7 @@ async def test_openai_abatch() -> None:
 @pytest.mark.scheduled
 async def test_openai_abatch_tags() -> None:
     """Test batch tokens from ChatOpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     result = await llm.abatch(
         ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
@@ -281,7 +283,7 @@ async def test_openai_abatch_tags() -> None:
 @pytest.mark.scheduled
 def test_openai_batch() -> None:
     """Test batch tokens from ChatOpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
     for token in result:
@@ -291,7 +293,7 @@ def test_openai_batch() -> None:
 @pytest.mark.scheduled
 async def test_openai_ainvoke() -> None:
     """Test invoke tokens from ChatOpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
@@ -300,7 +302,7 @@ async def test_openai_ainvoke() -> None:
 @pytest.mark.scheduled
 def test_openai_invoke() -> None:
     """Test invoke tokens from ChatOpenAI."""
-    llm = ChatOpenAI(max_tokens=10)
+    llm = ChatOpenAI(max_completion_tokens=10)

     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
@@ -385,7 +387,7 @@ async def test_astream() -> None:
     assert chunks_with_token_counts == 0
     assert full.usage_metadata is None

-    llm = ChatOpenAI(temperature=0, max_tokens=5)
+    llm = ChatOpenAI(temperature=0, max_completion_tokens=5)
     await _test_stream(llm.astream("Hello"), expect_usage=False)
     await _test_stream(
         llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
@@ -393,7 +395,7 @@ async def test_astream() -> None:
     await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
     llm = ChatOpenAI(
         temperature=0,
-        max_tokens=5,
+        max_completion_tokens=5,
         model_kwargs={"stream_options": {"include_usage": True}},
     )
     await _test_stream(llm.astream("Hello"), expect_usage=True)
@@ -401,7 +403,7 @@ async def test_astream() -> None:
         llm.astream("Hello", stream_options={"include_usage": False}),
         expect_usage=False,
     )
-    llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True)
+    llm = ChatOpenAI(temperature=0, max_completion_tokens=5, stream_usage=True)
     await _test_stream(llm.astream("Hello"), expect_usage=True)
     await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)

@@ -666,7 +668,7 @@ def test_openai_response_headers() -> None:
     """Test ChatOpenAI response headers."""
     chat_openai = ChatOpenAI(include_response_headers=True)
     query = "I'm Pickle Rick"
-    result = chat_openai.invoke(query, max_tokens=10)
+    result = chat_openai.invoke(query, max_completion_tokens=10)
     headers = result.response_metadata["headers"]
     assert headers
     assert isinstance(headers, dict)
@@ -674,7 +676,7 @@ def test_openai_response_headers() -> None:

     # Stream
     full: Optional[BaseMessageChunk] = None
-    for chunk in chat_openai.stream(query, max_tokens=10):
+    for chunk in chat_openai.stream(query, max_completion_tokens=10):
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessage)
     headers = full.response_metadata["headers"]
@@ -687,7 +689,7 @@ async def test_openai_response_headers_async() -> None:
     """Test ChatOpenAI response headers."""
     chat_openai = ChatOpenAI(include_response_headers=True)
     query = "I'm Pickle Rick"
-    result = await chat_openai.ainvoke(query, max_tokens=10)
+    result = await chat_openai.ainvoke(query, max_completion_tokens=10)
     headers = result.response_metadata["headers"]
     assert headers
     assert isinstance(headers, dict)
@@ -695,7 +697,7 @@ async def test_openai_response_headers_async() -> None:

     # Stream
     full: Optional[BaseMessageChunk] = None
-    async for chunk in chat_openai.astream(query, max_tokens=10):
+    async for chunk in chat_openai.astream(query, max_completion_tokens=10):
         full = chunk if full is None else full + chunk
     assert isinstance(full, AIMessage)
     headers = full.response_metadata["headers"]
@@ -1085,3 +1087,13 @@ async def test_astream_response_format() -> None:
         "how are ya", response_format=Foo
     ):
         pass
+
+
+def test_o1_max_tokens() -> None:
+    response = ChatOpenAI(model="o1-mini", max_tokens=10).invoke("how are you")  # type: ignore[call-arg]
+    assert isinstance(response, AIMessage)
+
+    response = ChatOpenAI(model="gpt-4o", max_completion_tokens=10).invoke(
+        "how are you"
+    )
+    assert isinstance(response, AIMessage)
@@ -36,6 +36,11 @@ def test_openai_model_param() -> None:
     llm = ChatOpenAI(model_name="foo")  # type: ignore[call-arg]
     assert llm.model_name == "foo"

+    llm = ChatOpenAI(max_tokens=10)  # type: ignore[call-arg]
+    assert llm.max_tokens == 10
+    llm = ChatOpenAI(max_completion_tokens=10)
+    assert llm.max_tokens == 10
+

 def test_openai_o1_temperature() -> None:
     llm = ChatOpenAI(model="o1-preview")