mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
Fetch up-to-date attributes for env-pulled kwargs during serialisation of OpenAI classes (#11499)
This commit is contained in:
parent
c3d2b01adf
commit
484947c492
@ -141,6 +141,13 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "azure-openai-chat"
|
return "azure-openai-chat"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"openai_api_type": self.openai_api_type,
|
||||||
|
"openai_api_version": self.openai_api_version,
|
||||||
|
}
|
||||||
|
|
||||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||||
for res in response["choices"]:
|
for res in response["choices"]:
|
||||||
if res.get("finish_reason", None) == "content_filter":
|
if res.get("finish_reason", None) == "content_filter":
|
||||||
|
@ -141,6 +141,21 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
def lc_secrets(self) -> Dict[str, str]:
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if self.openai_organization != "":
|
||||||
|
attributes["openai_organization"] = self.openai_organization
|
||||||
|
|
||||||
|
if self.openai_api_base != "":
|
||||||
|
attributes["openai_api_base"] = self.openai_api_base
|
||||||
|
|
||||||
|
if self.openai_proxy != "":
|
||||||
|
attributes["openai_proxy"] = self.openai_proxy
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
"""Return whether this model can be serialized by Langchain."""
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
@ -138,6 +138,20 @@ class BaseOpenAI(BaseLLM):
|
|||||||
def lc_secrets(self) -> Dict[str, str]:
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
if self.openai_api_base != "":
|
||||||
|
attributes["openai_api_base"] = self.openai_api_base
|
||||||
|
|
||||||
|
if self.openai_organization != "":
|
||||||
|
attributes["openai_organization"] = self.openai_organization
|
||||||
|
|
||||||
|
if self.openai_proxy != "":
|
||||||
|
attributes["openai_proxy"] = self.openai_proxy
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -692,6 +706,13 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "azure"
|
return "azure"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"openai_api_type": self.openai_api_type,
|
||||||
|
"openai_api_version": self.openai_api_version,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChat(BaseLLM):
|
class OpenAIChat(BaseLLM):
|
||||||
"""OpenAI Chat large language models.
|
"""OpenAI Chat large language models.
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Mapping, cast
|
from typing import Any, Mapping, cast
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = "test"
|
|
||||||
os.environ["OPENAI_API_BASE"] = "https://oai.azure.com/"
|
|
||||||
os.environ["OPENAI_API_VERSION"] = "2023-05-01"
|
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"OPENAI_API_KEY": "test",
|
||||||
|
"OPENAI_API_BASE": "https://oai.azure.com/",
|
||||||
|
"OPENAI_API_VERSION": "2023-05-01",
|
||||||
|
},
|
||||||
|
)
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
|
"model_name", ["gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k"]
|
||||||
|
Loading…
Reference in New Issue
Block a user