mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(openai): infer azure chat profiles from model name (#36858)
This commit is contained in:
@@ -705,6 +705,10 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _resolve_model_profile(self) -> ModelProfile | None:
|
def _resolve_model_profile(self) -> ModelProfile | None:
|
||||||
|
if (self.model_name is not None) and (
|
||||||
|
profile := _get_default_model_profile(self.model_name) or None
|
||||||
|
):
|
||||||
|
return profile
|
||||||
if self.deployment_name is not None:
|
if self.deployment_name is not None:
|
||||||
return _get_default_model_profile(self.deployment_name) or None
|
return _get_default_model_profile(self.deployment_name) or None
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ def test_initialize_azure_openai() -> None:
|
|||||||
azure_deployment="35-turbo-dev",
|
azure_deployment="35-turbo-dev",
|
||||||
openai_api_version="2023-05-15",
|
openai_api_version="2023-05-15",
|
||||||
azure_endpoint="my-base-url",
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
)
|
)
|
||||||
assert llm.deployment_name == "35-turbo-dev"
|
assert llm.deployment_name == "35-turbo-dev"
|
||||||
assert llm.openai_api_version == "2023-05-15"
|
assert llm.openai_api_version == "2023-05-15"
|
||||||
@@ -45,6 +46,92 @@ def test_initialize_more() -> None:
|
|||||||
assert ls_params.get("ls_model_name") == "gpt-35-turbo-0125"
|
assert ls_params.get("ls_model_name") == "gpt-35-turbo-0125"
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_resolves_from_model_name() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
model="gpt-4o",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile
|
||||||
|
assert llm.profile["name"] == "GPT-4o"
|
||||||
|
assert llm.profile["max_input_tokens"] == 128_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_resolves_from_model_name_with_custom_deployment_alias() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
model="gpt-4o",
|
||||||
|
azure_deployment="35-turbo-dev",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile
|
||||||
|
assert llm.profile["name"] == "GPT-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_prefers_model_name_over_known_deployment_name() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
model="gpt-4o",
|
||||||
|
azure_deployment="gpt-4",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile
|
||||||
|
assert llm.profile["name"] == "GPT-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_falls_back_to_deployment_name_with_unknown_model() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
model="unknown-model",
|
||||||
|
azure_deployment="gpt-4o",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_resolves_from_deployment_name_without_model() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
azure_deployment="gpt-4o",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile
|
||||||
|
assert llm.profile["name"] == "GPT-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_respects_explicit_profile() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
model="gpt-4o",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
profile={"tool_calling": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile == {"tool_calling": False}
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_is_none_for_unknown_deployment_without_model() -> None:
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
azure_deployment="unknown-deployment",
|
||||||
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
|
api_version="2023-05-15",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert llm.profile is None
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
|
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
|
||||||
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
|
with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}):
|
||||||
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
|
llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg]
|
||||||
@@ -79,6 +166,7 @@ def test_structured_output_old_model() -> None:
|
|||||||
azure_deployment="35-turbo-dev",
|
azure_deployment="35-turbo-dev",
|
||||||
openai_api_version="2023-05-15",
|
openai_api_version="2023-05-15",
|
||||||
azure_endpoint="my-base-url",
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
).with_structured_output(Output)
|
).with_structured_output(Output)
|
||||||
|
|
||||||
# assert tool calling was used instead of json_schema
|
# assert tool calling was used instead of json_schema
|
||||||
@@ -91,6 +179,7 @@ def test_max_completion_tokens_in_payload() -> None:
|
|||||||
azure_deployment="o1-mini",
|
azure_deployment="o1-mini",
|
||||||
api_version="2024-12-01-preview",
|
api_version="2024-12-01-preview",
|
||||||
azure_endpoint="my-base-url",
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
model_kwargs={"max_completion_tokens": 300},
|
model_kwargs={"max_completion_tokens": 300},
|
||||||
)
|
)
|
||||||
messages = [HumanMessage("Hello")]
|
messages = [HumanMessage("Hello")]
|
||||||
@@ -148,6 +237,7 @@ def test_max_completion_tokens_parameter() -> None:
|
|||||||
azure_deployment="gpt-5",
|
azure_deployment="gpt-5",
|
||||||
api_version="2024-12-01-preview",
|
api_version="2024-12-01-preview",
|
||||||
azure_endpoint="my-base-url",
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
max_completion_tokens=1500,
|
max_completion_tokens=1500,
|
||||||
)
|
)
|
||||||
messages = [HumanMessage("Hello")]
|
messages = [HumanMessage("Hello")]
|
||||||
@@ -165,6 +255,7 @@ def test_max_tokens_converted_to_max_completion_tokens() -> None:
|
|||||||
azure_deployment="gpt-5",
|
azure_deployment="gpt-5",
|
||||||
api_version="2024-12-01-preview",
|
api_version="2024-12-01-preview",
|
||||||
azure_endpoint="my-base-url",
|
azure_endpoint="my-base-url",
|
||||||
|
api_key=SecretStr("test"),
|
||||||
max_tokens=1000, # type: ignore[call-arg]
|
max_tokens=1000, # type: ignore[call-arg]
|
||||||
)
|
)
|
||||||
messages = [HumanMessage("Hello")]
|
messages = [HumanMessage("Hello")]
|
||||||
|
|||||||
Reference in New Issue
Block a user