mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 06:53:59 +00:00
openai[patch]: fix azure open lc serialization, release 0.1.5 (#21159)
This commit is contained in:
@@ -10,12 +10,12 @@ from langchain_core.outputs import ChatResult
|
|||||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
|
|
||||||
from langchain_openai.chat_models.base import ChatOpenAI
|
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AzureChatOpenAI(ChatOpenAI):
|
class AzureChatOpenAI(BaseChatOpenAI):
|
||||||
"""`Azure OpenAI` Chat Completion API.
|
"""`Azure OpenAI` Chat Completion API.
|
||||||
|
|
||||||
To use this class you
|
To use this class you
|
||||||
@@ -100,6 +100,17 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "chat_models", "azure_openai"]
|
return ["langchain", "chat_models", "azure_openai"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"openai_api_key": "AZURE_OPENAI_API_KEY",
|
||||||
|
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
@@ -291,52 +291,7 @@ class _AllReturnType(TypedDict):
|
|||||||
parsing_error: Optional[BaseException]
|
parsing_error: Optional[BaseException]
|
||||||
|
|
||||||
|
|
||||||
class ChatOpenAI(BaseChatModel):
|
class BaseChatOpenAI(BaseChatModel):
|
||||||
"""`OpenAI` Chat large language models API.
|
|
||||||
|
|
||||||
To use, you should have the environment variable ``OPENAI_API_KEY``
|
|
||||||
set with your API key, or pass it as a named parameter to the constructor.
|
|
||||||
|
|
||||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
|
||||||
in, even if not explicitly saved on this class.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
model = ChatOpenAI(model="gpt-3.5-turbo")
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_secrets(self) -> Dict[str, str]:
|
|
||||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
|
||||||
"""Get the namespace of the langchain object."""
|
|
||||||
return ["langchain", "chat_models", "openai"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_attributes(self) -> Dict[str, Any]:
|
|
||||||
attributes: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
if self.openai_organization:
|
|
||||||
attributes["openai_organization"] = self.openai_organization
|
|
||||||
|
|
||||||
if self.openai_api_base:
|
|
||||||
attributes["openai_api_base"] = self.openai_api_base
|
|
||||||
|
|
||||||
if self.openai_proxy:
|
|
||||||
attributes["openai_proxy"] = self.openai_proxy
|
|
||||||
|
|
||||||
return attributes
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
"""Return whether this model can be serialized by Langchain."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
|
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
|
||||||
@@ -1093,6 +1048,53 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
return llm | output_parser
|
return llm | output_parser
|
||||||
|
|
||||||
|
|
||||||
|
class ChatOpenAI(BaseChatOpenAI):
|
||||||
|
"""`OpenAI` Chat large language models API.
|
||||||
|
|
||||||
|
To use, you should have the environment variable ``OPENAI_API_KEY``
|
||||||
|
set with your API key, or pass it as a named parameter to the constructor.
|
||||||
|
|
||||||
|
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||||
|
in, even if not explicitly saved on this class.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
model = ChatOpenAI(model="gpt-3.5-turbo")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
"""Get the namespace of the langchain object."""
|
||||||
|
return ["langchain", "chat_models", "openai"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if self.openai_organization:
|
||||||
|
attributes["openai_organization"] = self.openai_organization
|
||||||
|
|
||||||
|
if self.openai_api_base:
|
||||||
|
attributes["openai_api_base"] = self.openai_api_base
|
||||||
|
|
||||||
|
if self.openai_proxy:
|
||||||
|
attributes["openai_proxy"] = self.openai_proxy
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||||
|
|
||||||
|
@@ -72,6 +72,18 @@ class AzureOpenAI(BaseOpenAI):
|
|||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
return ["langchain", "llms", "openai"]
|
return ["langchain", "llms", "openai"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"openai_api_key": "AZURE_OPENAI_API_KEY",
|
||||||
|
"azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
return True
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
@@ -68,24 +68,6 @@ def _stream_response_to_generation_chunk(
|
|||||||
class BaseOpenAI(BaseLLM):
|
class BaseOpenAI(BaseLLM):
|
||||||
"""Base OpenAI large language model class."""
|
"""Base OpenAI large language model class."""
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_secrets(self) -> Dict[str, str]:
|
|
||||||
return {"openai_api_key": "OPENAI_API_KEY"}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_attributes(self) -> Dict[str, Any]:
|
|
||||||
attributes: Dict[str, Any] = {}
|
|
||||||
if self.openai_api_base:
|
|
||||||
attributes["openai_api_base"] = self.openai_api_base
|
|
||||||
|
|
||||||
if self.openai_organization:
|
|
||||||
attributes["openai_organization"] = self.openai_organization
|
|
||||||
|
|
||||||
if self.openai_proxy:
|
|
||||||
attributes["openai_proxy"] = self.openai_proxy
|
|
||||||
|
|
||||||
return attributes
|
|
||||||
|
|
||||||
client: Any = Field(default=None, exclude=True) #: :meta private:
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
async_client: Any = Field(default=None, exclude=True) #: :meta private:
|
||||||
model_name: str = Field(default="gpt-3.5-turbo-instruct", alias="model")
|
model_name: str = Field(default="gpt-3.5-turbo-instruct", alias="model")
|
||||||
@@ -649,3 +631,21 @@ class OpenAI(BaseOpenAI):
|
|||||||
@property
|
@property
|
||||||
def _invocation_params(self) -> Dict[str, Any]:
|
def _invocation_params(self) -> Dict[str, Any]:
|
||||||
return {**{"model": self.model_name}, **super()._invocation_params}
|
return {**{"model": self.model_name}, **super()._invocation_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
if self.openai_api_base:
|
||||||
|
attributes["openai_api_base"] = self.openai_api_base
|
||||||
|
|
||||||
|
if self.openai_organization:
|
||||||
|
attributes["openai_organization"] = self.openai_organization
|
||||||
|
|
||||||
|
if self.openai_proxy:
|
||||||
|
attributes["openai_proxy"] = self.openai_proxy
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain-openai"
|
name = "langchain-openai"
|
||||||
version = "0.1.4"
|
version = "0.1.5"
|
||||||
description = "An integration package connecting OpenAI and LangChain"
|
description = "An integration package connecting OpenAI and LangChain"
|
||||||
authors = []
|
authors = []
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
from typing import Type, cast
|
from typing import Type, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain_core.load import dumpd
|
||||||
from langchain_core.pydantic_v1 import SecretStr
|
from langchain_core.pydantic_v1 import SecretStr
|
||||||
from pytest import CaptureFixture, MonkeyPatch
|
from pytest import CaptureFixture, MonkeyPatch
|
||||||
|
|
||||||
@@ -187,3 +188,19 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: Type) -> No
|
|||||||
"""Test that the actual secret value is correctly retrieved."""
|
"""Test that the actual secret value is correctly retrieved."""
|
||||||
model = model_class(openai_api_key="secret-api-key")
|
model = model_class(openai_api_key="secret-api-key")
|
||||||
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
|
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
|
||||||
|
def test_azure_serialized_secrets(model_class: Type) -> None:
|
||||||
|
"""Test that the actual secret value is correctly retrieved."""
|
||||||
|
model = model_class(
|
||||||
|
openai_api_key="secret-api-key", api_version="foo", azure_endpoint="foo"
|
||||||
|
)
|
||||||
|
serialized = dumpd(model)
|
||||||
|
assert serialized["kwargs"]["openai_api_key"]["id"] == ["AZURE_OPENAI_API_KEY"]
|
||||||
|
|
||||||
|
model = model_class(
|
||||||
|
azure_ad_token="secret-token", api_version="foo", azure_endpoint="foo"
|
||||||
|
)
|
||||||
|
serialized = dumpd(model)
|
||||||
|
assert serialized["kwargs"]["azure_ad_token"]["id"] == ["AZURE_OPENAI_AD_TOKEN"]
|
||||||
|
Reference in New Issue
Block a user