openai[patch]: azure max completion tokens fix

This commit is contained in:
Bagatur 2024-11-26 15:55:47 -08:00
parent 8adc4a5bcc
commit eec4df0d0a
2 changed files with 22 additions and 0 deletions

View File

@ -737,3 +737,15 @@ class AzureChatOpenAI(BaseChatOpenAI):
)
return chat_result
@property
def _default_params(self) -> Dict[str, Any]:
    """Get the default parameters for calling the OpenAI API.

    For o1-family models, Azure OpenAI rejects the legacy ``max_tokens``
    parameter, so remap it to ``max_completion_tokens`` unless the caller
    already supplied one.

    Returns:
        Dict[str, Any]: the parameter dict from the parent class, with
        ``max_tokens`` renamed to ``max_completion_tokens`` when the model
        is an o1 variant.
    """
    params = super()._default_params
    # BUG FIX: the membership tests must run against the params dict, not
    # against the model-name string (params["model"]). The original
    # `"max_tokens" in params["model"]` checked for the substring
    # "max_tokens" inside e.g. "o1-mini", which is never true, so the
    # remapping never fired.
    if (
        "o1" in params["model"]  # substring match, e.g. "o1-mini", "o1-preview"
        and "max_tokens" in params
        and "max_completion_tokens" not in params
    ):
        # pop() both removes the rejected key and carries its value over.
        params["max_completion_tokens"] = params.pop("max_tokens")
    return params

View File

@ -262,3 +262,13 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}
def test_o1_max_tokens() -> None:
    """Check that ``max_tokens`` works on o1 models and ``max_completion_tokens`` on gpt-4o."""
    o1_llm = ChatOpenAI(model="o1-mini", max_tokens=10)  # type: ignore[call-arg]
    o1_response = o1_llm.invoke("how are you")
    assert isinstance(o1_response, AIMessage)

    gpt4o_llm = ChatOpenAI(model="gpt-4o", max_completion_tokens=10)
    gpt4o_response = gpt4o_llm.invoke("how are you")
    assert isinstance(gpt4o_response, AIMessage)