mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 09:58:44 +00:00
core[patch]: Hide aliases when serializing (#16888)
Currently, if you dump an object initialized with an alias, we'll still dump the secret values since they're retained in the kwargs
This commit is contained in:
parent
131c043864
commit
e02efd513f
@ -143,6 +143,14 @@ class Serializable(BaseModel, ABC):
|
|||||||
this = cast(Serializable, self if cls is None else super(cls, self))
|
this = cast(Serializable, self if cls is None else super(cls, self))
|
||||||
|
|
||||||
secrets.update(this.lc_secrets)
|
secrets.update(this.lc_secrets)
|
||||||
|
# Now also add the aliases for the secrets
|
||||||
|
# This ensures known secret aliases are hidden.
|
||||||
|
# Note: this does NOT hide any other extra kwargs
|
||||||
|
# that are not present in the fields.
|
||||||
|
for key in list(secrets):
|
||||||
|
value = secrets[key]
|
||||||
|
if key in this.__fields__:
|
||||||
|
secrets[this.__fields__[key].alias] = value
|
||||||
lc_kwargs.update(this.lc_attributes)
|
lc_kwargs.update(this.lc_attributes)
|
||||||
|
|
||||||
# include all secrets, even if not specified in kwargs
|
# include all secrets, even if not specified in kwargs
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test for Serializable base class"""
|
"""Test for Serializable base class"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@ -11,6 +12,7 @@ from langchain_core.load.dump import dumps
|
|||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import Field, root_validator
|
||||||
from langchain_core.tracers.langchain import LangChainTracer
|
from langchain_core.tracers.langchain import LangChainTracer
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -167,3 +169,85 @@ def test_person_with_invalid_kwargs() -> None:
|
|||||||
person = Person(secret="hello")
|
person = Person(secret="hello")
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
dumps(person, invalid_kwarg="hello")
|
dumps(person, invalid_kwarg="hello")
|
||||||
|
|
||||||
|
|
||||||
|
class TestClass(Serializable):
|
||||||
|
my_favorite_secret: str = Field(alias="my_favorite_secret_alias")
|
||||||
|
my_other_secret: str = Field()
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def get_from_env(cls, values: Dict) -> Dict:
|
||||||
|
"""Get the values from the environment."""
|
||||||
|
if "my_favorite_secret" not in values:
|
||||||
|
values["my_favorite_secret"] = os.getenv("MY_FAVORITE_SECRET")
|
||||||
|
if "my_other_secret" not in values:
|
||||||
|
values["my_other_secret"] = os.getenv("MY_OTHER_SECRET")
|
||||||
|
return values
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
|
return ["my", "special", "namespace"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"my_favorite_secret": "MY_FAVORITE_SECRET",
|
||||||
|
"my_other_secret": "MY_OTHER_SECRET",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_aliases_hidden() -> None:
|
||||||
|
test_class = TestClass(my_favorite_secret="hello", my_other_secret="world")
|
||||||
|
dumped = json.loads(dumps(test_class, pretty=True))
|
||||||
|
expected_dump = {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": ["my", "special", "namespace", "TestClass"],
|
||||||
|
"kwargs": {
|
||||||
|
"my_favorite_secret": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "secret",
|
||||||
|
"id": ["MY_FAVORITE_SECRET"],
|
||||||
|
},
|
||||||
|
"my_other_secret": {"lc": 1, "type": "secret", "id": ["MY_OTHER_SECRET"]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert dumped == expected_dump
|
||||||
|
# Check while patching the os environment
|
||||||
|
with patch.dict(
|
||||||
|
os.environ, {"MY_FAVORITE_SECRET": "hello", "MY_OTHER_SECRET": "world"}
|
||||||
|
):
|
||||||
|
test_class = TestClass()
|
||||||
|
dumped = json.loads(dumps(test_class, pretty=True))
|
||||||
|
|
||||||
|
# Check by alias
|
||||||
|
test_class = TestClass(my_favorite_secret_alias="hello", my_other_secret="world")
|
||||||
|
dumped = json.loads(dumps(test_class, pretty=True))
|
||||||
|
expected_dump = {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "constructor",
|
||||||
|
"id": ["my", "special", "namespace", "TestClass"],
|
||||||
|
"kwargs": {
|
||||||
|
"my_favorite_secret": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "secret",
|
||||||
|
"id": ["MY_FAVORITE_SECRET"],
|
||||||
|
},
|
||||||
|
"my_favorite_secret_alias": {
|
||||||
|
"lc": 1,
|
||||||
|
"type": "secret",
|
||||||
|
"id": ["MY_FAVORITE_SECRET"],
|
||||||
|
},
|
||||||
|
"my_other_secret": {"lc": 1, "type": "secret", "id": ["MY_OTHER_SECRET"]},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert dumped == expected_dump
|
||||||
|
Loading…
Reference in New Issue
Block a user