mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Fetch up-to-date attributes for env-pulled kwargs during serialisation of OpenAI classes (#11499)
This commit is contained in:
parent
c3d2b01adf
commit
484947c492
@ -141,6 +141,13 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
def _llm_type(self) -> str:
|
||||
return "azure-openai-chat"
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_api_type": self.openai_api_type,
|
||||
"openai_api_version": self.openai_api_version,
|
||||
}
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
for res in response["choices"]:
|
||||
if res.get("finish_reason", None) == "content_filter":
|
||||
|
@ -141,6 +141,21 @@ class ChatOpenAI(BaseChatModel):
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
if self.openai_organization != "":
|
||||
attributes["openai_organization"] = self.openai_organization
|
||||
|
||||
if self.openai_api_base != "":
|
||||
attributes["openai_api_base"] = self.openai_api_base
|
||||
|
||||
if self.openai_proxy != "":
|
||||
attributes["openai_proxy"] = self.openai_proxy
|
||||
|
||||
return attributes
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
|
@ -138,6 +138,20 @@ class BaseOpenAI(BaseLLM):
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
attributes: Dict[str, Any] = {}
|
||||
if self.openai_api_base != "":
|
||||
attributes["openai_api_base"] = self.openai_api_base
|
||||
|
||||
if self.openai_organization != "":
|
||||
attributes["openai_organization"] = self.openai_organization
|
||||
|
||||
if self.openai_proxy != "":
|
||||
attributes["openai_proxy"] = self.openai_proxy
|
||||
|
||||
return attributes
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
@ -692,6 +706,13 @@ class AzureOpenAI(BaseOpenAI):
|
||||
"""Return type of llm."""
|
||||
return "azure"
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_api_type": self.openai_api_type,
|
||||
"openai_api_version": self.openai_api_version,
|
||||
}
|
||||
|
||||
|
||||
class OpenAIChat(BaseLLM):
|
||||
"""OpenAI Chat large language models.
|
||||
|
@ -1,16 +1,21 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Mapping, cast
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "test"
|
||||
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
|
||||
os.environ["OPENAI_API_VERSION"] = "2023-05-01"
|
||||
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"OPENAI_API_KEY": "test",
|
||||
"OPENAI_API_BASE": "https://oai.azure.com/",
|
||||
"OPENAI_API_VERSION": "2023-05-01",
|
||||
},
|
||||
)
|
||||
@pytest.mark.requires("openai")
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
|
||||
|
Loading…
Reference in New Issue
Block a user