openai[patch]: type reasoning_effort (#28825)

Bagatur authored on 2024-12-19 11:36:49 -08:00 (committed by GitHub)
parent 6a37899b39
commit 1378ddfa5f
5 changed files with 29 additions and 18 deletions

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -454,6 +454,16 @@ class BaseChatOpenAI(BaseChatModel):
     """Total probability mass of tokens to consider at each step."""
     max_tokens: Optional[int] = Field(default=None)
     """Maximum number of tokens to generate."""
+    reasoning_effort: Optional[str] = None
+    """Constrains effort on reasoning for reasoning models.
+
+    o1 models only.
+
+    Currently supported values are low, medium, and high. Reducing reasoning effort
+    can result in faster responses and fewer tokens used on reasoning in a response.
+
+    .. versionadded:: 0.2.14
+    """
     tiktoken_model_name: Optional[str] = None
     """The model name to pass to tiktoken when using this class.
     Tiktoken is used to count the number of tokens in documents to constrain
@@ -599,6 +609,7 @@ class BaseChatOpenAI(BaseChatModel):
             "stop": self.stop or None,  # also exclude empty list for this
             "max_tokens": self.max_tokens,
             "extra_body": self.extra_body,
+            "reasoning_effort": self.reasoning_effort,
         }
         params = {
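
The new attribute mirrors OpenAI's reasoning_effort request parameter and, per the second hunk, is passed through to the request params unchanged. A minimal usage sketch (hypothetical prompt; assumes an API key with o1 access):

from langchain_openai import ChatOpenAI

# "low" trades reasoning depth for faster responses and fewer reasoning
# tokens; "medium" and "high" are the other supported values.
llm = ChatOpenAI(model="o1", reasoning_effort="low")
response = llm.invoke("How are you?")
print(response.content)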

libs/partners/openai/poetry.lock

@@ -546,13 +546,13 @@ url = "../../standard-tests"
 
 [[package]]
 name = "langsmith"
-version = "0.2.3"
+version = "0.2.4"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.9"
 files = [
-    {file = "langsmith-0.2.3-py3-none-any.whl", hash = "sha256:4958b6e918f57fedba6ddc55b8534d1e06478bb44c779aa73713ce898ca6ae87"},
-    {file = "langsmith-0.2.3.tar.gz", hash = "sha256:54c231b07fdeb0f8472925074a0ec0ed2cb654a0437d63c6ccf76a9da635900d"},
+    {file = "langsmith-0.2.4-py3-none-any.whl", hash = "sha256:fa797e4ecba968b76bccf351053e48bd6c6de7455515588cb46b74765e8a4127"},
+    {file = "langsmith-0.2.4.tar.gz", hash = "sha256:386fedc815b45f94fa17571860e561c0d79facc0a2979a532eb8b893a4d98fa9"},
 ]
 
 [package.dependencies]
@@ -1647,4 +1647,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "1bc76c2d222284109d4dd1bb79d309c4157cb9fe9d426870183cb32fa16465be"
+content-hash = "71de53990a6cfb9cd6a25249b40eeef52e089840a9a06b54ac556fe7fa60504c"

libs/partners/openai/pyproject.toml

@@ -24,7 +24,7 @@ ignore_missing_imports = true
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
 langchain-core = "^0.3.27"
-openai = "^1.55.3"
+openai = "^1.58.1"
 tiktoken = ">=0.7,<1"
 
 [tool.ruff.lint]
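
The openai floor rises to ^1.58.1, presumably the first SDK release whose request types include reasoning_effort. A sketch of the same parameter against the raw SDK (hypothetical prompt; OPENAI_API_KEY read from the environment):

from openai import OpenAI

client = OpenAI()
completion = client.chat.completions.create(
    model="o1",
    reasoning_effort="low",  # the parameter this version bump types
    messages=[{"role": "user", "content": "How are you?"}],
)
print(completion.choices[0].message.content)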

libs/partners/openai/tests/integration_tests/chat_models/test_base.py

@@ -1089,19 +1089,13 @@ async def test_astream_response_format() -> None:
         pass
 
 
-def test_o1_max_tokens() -> None:
-    response = ChatOpenAI(model="o1-mini", max_tokens=10).invoke("how are you")  # type: ignore[call-arg]
-    assert isinstance(response, AIMessage)
-
-    response = ChatOpenAI(model="gpt-4o", max_completion_tokens=10).invoke(
-        "how are you"
-    )
-    assert isinstance(response, AIMessage)
-
-
-def test_developer_message() -> None:
-    llm = ChatOpenAI(model="o1", max_tokens=10)  # type: ignore[call-arg]
-    response = llm.invoke(
+@pytest.mark.parametrize("use_max_completion_tokens", [True, False])
+def test_o1(use_max_completion_tokens: bool) -> None:
+    if use_max_completion_tokens:
+        kwargs: dict = {"max_completion_tokens": 10}
+    else:
+        kwargs = {"max_tokens": 10}
+    response = ChatOpenAI(model="o1", reasoning_effort="low", **kwargs).invoke(
         [
             {"role": "developer", "content": "respond in all caps"},
             {"role": "user", "content": "HOW ARE YOU"},

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

@@ -882,3 +882,9 @@ def test__get_request_payload() -> None:
     }
     payload = llm._get_request_payload(messages)
     assert payload == expected
+
+
+def test_init_o1() -> None:
+    with pytest.warns(None) as record:  # type: ignore[call-overload]
+        ChatOpenAI(model="o1", reasoning_effort="medium")
+    assert len(record) == 0
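
Note that pytest.warns(None) is deprecated in pytest 7+ (hence the type-ignore in the test). An equivalent no-warning check using only the standard library, as an alternative sketch:

import warnings

from langchain_openai import ChatOpenAI

# With reasoning_effort now a typed field, construction should emit no
# warnings; presumably it was previously shunted into model_kwargs with
# a warning as an unrecognized parameter.
with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    ChatOpenAI(model="o1", reasoning_effort="medium")
assert len(record) == 0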