openai[patch]: use model_name in AzureOpenAI.ls_model_name (#24366)

This commit is contained in:
Bagatur 2024-07-17 15:24:05 -07:00 committed by GitHub
parent eb26b5535a
commit 7d83189b19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 2 deletions

View File

@ -95,6 +95,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
var OPENAI_ORG_ID.
model: Optional[str]
The name of the underlying OpenAI model. Used for tracing and token
counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc.
model_version: Optional[str]
The version of the underlying OpenAI model. Used for tracing and token
counting. Does not affect completion. E.g., "0125", "0125-preview", etc.
See full list of supported init args and their descriptions in the params section.
@ -111,6 +117,8 @@ class AzureChatOpenAI(BaseChatOpenAI):
timeout=None,
max_retries=2,
# organization="...",
# model="gpt-35-turbo",
# model_version="0125",
# other params...
)
@ -514,6 +522,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
azure_endpoint and update client params accordingly.
"""
model_name: Optional[str] = Field(default=None, alias="model") # type: ignore[assignment]
"""Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.
Distinct from the Azure deployment name, which is set by the Azure user.
Used for tracing and token counting. Does NOT affect completion.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
@ -923,7 +938,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
"""Get the parameters used to invoke the model."""
params = super()._get_ls_params(stop=stop, **kwargs)
params["ls_provider"] = "azure"
if self.deployment_name:
if self.model_name:
if self.model_version and self.model_version not in self.model_name:
params["ls_model_name"] = (
self.model_name + "-" + self.model_version.lstrip("-")
)
else:
params["ls_model_name"] = self.model_name
elif self.deployment_name:
params["ls_model_name"] = self.deployment_name
return params

View File

@ -23,6 +23,8 @@ def test_initialize_more() -> None:
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
temperature=0,
model="gpt-35-turbo",
model_version="0125",
)
assert llm.openai_api_key is not None
assert llm.openai_api_key.get_secret_value() == "xyz"
@ -33,7 +35,7 @@ def test_initialize_more() -> None:
ls_params = llm._get_ls_params()
assert ls_params["ls_provider"] == "azure"
assert ls_params["ls_model_name"] == "35-turbo-dev"
assert ls_params["ls_model_name"] == "gpt-35-turbo-0125"
def test_initialize_azure_openai_with_openai_api_base_set() -> None: