mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
openai[patch]: use model_name in AzureOpenAI.ls_model_name (#24366)
This commit is contained in:
parent
eb26b5535a
commit
7d83189b19
@ -95,6 +95,12 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
organization: Optional[str]
|
||||
OpenAI organization ID. If not passed in will be read from env
|
||||
var OPENAI_ORG_ID.
|
||||
model: Optional[str]
|
||||
The name of the underlying OpenAI model. Used for tracing and token
|
||||
counting. Does not affect completion. E.g. "gpt-4", "gpt-35-turbo", etc.
|
||||
model_version: Optional[str]
|
||||
The version of the underlying OpenAI model. Used for tracing and token
|
||||
counting. Does not affect completion. E.g., "0125", "0125-preview", etc.
|
||||
|
||||
See full list of supported init args and their descriptions in the params section.
|
||||
|
||||
@ -111,6 +117,8 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
# organization="...",
|
||||
# model="gpt-35-turbo",
|
||||
# model_version="0125",
|
||||
# other params...
|
||||
)
|
||||
|
||||
@ -514,6 +522,13 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
azure_endpoint and update client params accordingly.
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = Field(default=None, alias="model") # type: ignore[assignment]
|
||||
"""Name of the deployed OpenAI model, e.g. "gpt-4o", "gpt-35-turbo", etc.
|
||||
|
||||
Distinct from the Azure deployment name, which is set by the Azure user.
|
||||
Used for tracing and token counting. Does NOT affect completion.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@ -923,7 +938,14 @@ class AzureChatOpenAI(BaseChatOpenAI):
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
params["ls_provider"] = "azure"
|
||||
if self.deployment_name:
|
||||
if self.model_name:
|
||||
if self.model_version and self.model_version not in self.model_name:
|
||||
params["ls_model_name"] = (
|
||||
self.model_name + "-" + self.model_version.lstrip("-")
|
||||
)
|
||||
else:
|
||||
params["ls_model_name"] = self.model_name
|
||||
elif self.deployment_name:
|
||||
params["ls_model_name"] = self.deployment_name
|
||||
return params
|
||||
|
||||
|
@ -23,6 +23,8 @@ def test_initialize_more() -> None:
|
||||
azure_deployment="35-turbo-dev",
|
||||
openai_api_version="2023-05-15",
|
||||
temperature=0,
|
||||
model="gpt-35-turbo",
|
||||
model_version="0125",
|
||||
)
|
||||
assert llm.openai_api_key is not None
|
||||
assert llm.openai_api_key.get_secret_value() == "xyz"
|
||||
@ -33,7 +35,7 @@ def test_initialize_more() -> None:
|
||||
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params["ls_provider"] == "azure"
|
||||
assert ls_params["ls_model_name"] == "35-turbo-dev"
|
||||
assert ls_params["ls_model_name"] == "gpt-35-turbo-0125"
|
||||
|
||||
|
||||
def test_initialize_azure_openai_with_openai_api_base_set() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user