fix: core: Include fields set outside the constructor in JSON output (#21342)

Nuno Campos 2024-05-06 14:37:36 -07:00 committed by GitHub
parent ac14f171ac
commit 6f17158606
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 480 additions and 352 deletions

View File

@@ -12,7 +12,7 @@ from typing import (
from typing_extensions import NotRequired
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_core.pydantic_v1 import BaseModel
class BaseSerialized(TypedDict):
@@ -114,12 +114,6 @@ class Serializable(BaseModel, ABC):
if (k not in self.__fields__ or try_neq_default(v, k, self))
]
_lc_kwargs = PrivateAttr(default_factory=dict)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.is_lc_serializable():
return self.to_json_not_implemented()
@@ -128,8 +122,9 @@ class Serializable(BaseModel, ABC):
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self._lc_kwargs.items()
for k, v in self
if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
and _is_field_useful(self, k, v)
}
# Merge the lc_secrets and lc_attributes from every class in the MRO
@@ -186,6 +181,23 @@ class Serializable(BaseModel, ABC):
return to_json_not_implemented(self)
def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
"""Check if a field is useful as a constructor argument.
Args:
inst: The instance.
key: The key.
value: The value.
Returns:
Whether the field is useful.
"""
field = inst.__fields__.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
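
The net effect of this hunk: to_json now iterates the model's current field values (for k, v in self) instead of the kwargs captured at construction time, so fields set after __init__ are serialized too, with _is_field_useful filtering out unset defaults. A minimal sketch of the new behavior, using a hypothetical Greeting subclass (not part of this diff):

from langchain_core.load import dumpd
from langchain_core.load.serializable import Serializable

class Greeting(Serializable):  # hypothetical example class
    text: str = "hi"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

g = Greeting()
g.text = "hello"  # set outside the constructor
dumpd(g)["kwargs"]  # {"text": "hello"} with this fix; previously {} because
                    # only the kwargs recorded by the removed __init__ were kept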

View File

@@ -1,9 +1,11 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence
from uuid import UUID
import pytest
from pytest import Config, Function, Parser
from pytest_mock import MockerFixture
def pytest_addoption(parser: Parser) -> None:
@@ -85,3 +87,11 @@ def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) ->
item.add_marker(
pytest.mark.skip(reason="Skipping not an extended test.")
)
@pytest.fixture()
def deterministic_uuids(mocker: MockerFixture) -> MockerFixture:
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
return mocker.patch("uuid.uuid4", side_effect=side_effect)
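
The fixture replaces uuid.uuid4 with a generator of sequential v4-shaped UUIDs, making run and message IDs stable across test runs. A usage sketch, assuming pytest-mock is installed:

from pytest_mock import MockerFixture

def test_ids_are_stable(deterministic_uuids: MockerFixture) -> None:
    import uuid

    # Each uuid.uuid4() call yields the next UUID in the fixed sequence.
    assert str(uuid.uuid4()) == "00000000-0000-4000-8000-000000000000"
    assert str(uuid.uuid4()) == "00000000-0000-4000-8000-000000000001"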

View File

@@ -21,6 +21,7 @@ def test_serdes_message() -> None:
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessage"],
"kwargs": {
"type": "ai",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [
@@ -46,6 +47,7 @@ def test_serdes_message_chunk() -> None:
"type": "constructor",
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
"kwargs": {
"type": "AIMessageChunk",
"content": [{"text": "blah", "type": "text"}],
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
"invalid_tool_calls": [

View File

@@ -74,7 +74,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
@@ -109,8 +108,7 @@
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
@@ -151,7 +149,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
@@ -200,8 +197,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
@@ -284,8 +280,7 @@
"buz"
],
"template": "what did baz say to {buz}",
"template_format": "f-string",
"partial_variables": {}
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
@@ -326,7 +321,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "not_implemented",
@@ -375,8 +369,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
@@ -445,8 +438,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {
@@ -486,8 +478,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
@@ -579,11 +570,7 @@
"runnable",
"RunnablePassthrough"
],
"kwargs": {
"func": null,
"afunc": null,
"input_type": null
},
"kwargs": {},
"name": "RunnablePassthrough",
"graph": {
"nodes": [
@@ -664,7 +651,6 @@
]
}
},
"middle": [],
"last": {
"lc": 1,
"type": "constructor",
@@ -750,8 +736,7 @@
}
]
}
},
"name": null
}
},
"name": "RunnableSequence",
"graph": {
@@ -933,8 +918,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {
@@ -1148,8 +1132,7 @@
],
"repr": "<class 'Exception'>"
}
],
"exception_key": null
]
},
"name": "RunnableWithFallbacks",
"graph": {

File diff suppressed because one or more lines are too long

View File

@@ -48,6 +48,8 @@ from langchain_core.output_parsers import (
CommaSeparatedListOutputParser,
StrOutputParser,
)
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
ChatPromptTemplate,
@@ -111,7 +113,19 @@ class FakeTracer(BaseTracer):
def _replace_message_id(self, maybe_message: Any) -> Any:
if isinstance(maybe_message, BaseMessage):
maybe_message.id = AnyStr()
maybe_message.id = str(next(self.uuids_generator))
if isinstance(maybe_message, ChatGeneration):
maybe_message.message.id = str(next(self.uuids_generator))
if isinstance(maybe_message, LLMResult):
for i, gen_list in enumerate(maybe_message.generations):
for j, gen in enumerate(gen_list):
maybe_message.generations[i][j] = self._replace_message_id(gen)
if isinstance(maybe_message, dict):
for k, v in maybe_message.items():
maybe_message[k] = self._replace_message_id(v)
if isinstance(maybe_message, list):
for i, v in enumerate(maybe_message):
maybe_message[i] = self._replace_message_id(v)
return maybe_message
@@ -136,16 +150,8 @@ class FakeTracer(BaseTracer):
"child_runs": [self._copy_run(child) for child in run.child_runs],
"trace_id": self._replace_uuid(run.trace_id) if run.trace_id else None,
"dotted_order": new_dotted_order,
"inputs": (
{k: self._replace_message_id(v) for k, v in run.inputs.items()}
if isinstance(run.inputs, dict)
else run.inputs
),
"outputs": (
{k: self._replace_message_id(v) for k, v in run.outputs.items()}
if isinstance(run.outputs, dict)
else run.outputs
),
"inputs": self._replace_message_id(run.inputs),
"outputs": self._replace_message_id(run.outputs),
}
)
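
_replace_message_id now recurses through dicts, lists, ChatGeneration wrappers, and LLMResult generations, which is why the per-key inputs/outputs handling above collapses into two direct calls. The same recursive-walk pattern in isolation, as a standalone sketch rather than the test's exact helper:

from typing import Any, Callable

def normalize_ids(value: Any, next_id: Callable[[], str]) -> Any:
    # Descend into containers, rewriting the id of any id-bearing leaf.
    if isinstance(value, dict):
        return {k: normalize_ids(v, next_id) for k, v in value.items()}
    if isinstance(value, list):
        return [normalize_ids(v, next_id) for v in value]
    if hasattr(value, "id"):
        value.id = next_id()
    return value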
@@ -1939,7 +1945,9 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
@freeze_time("2023-01-01")
def test_prompt_with_chat_model(
mocker: MockerFixture, snapshot: SnapshotAssertion
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -2043,7 +2051,9 @@ def test_prompt_with_chat_model(
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model_async(
mocker: MockerFixture, snapshot: SnapshotAssertion
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -2770,7 +2780,9 @@ async def test_prompt_with_llm_and_async_lambda(
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@@ -2809,7 +2821,9 @@ def test_prompt_with_chat_model_and_parser(
@freeze_time("2023-01-01")
def test_combining_sequences(
mocker: MockerFixture, snapshot: SnapshotAssertion
mocker: MockerFixture,
snapshot: SnapshotAssertion,
deterministic_uuids: MockerFixture,
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")

View File

@@ -36,13 +36,6 @@
"SpecialPerson"
],
"kwargs": {
"another_secret": {
"lc": 1,
"type": "secret",
"id": [
"ANOTHER_SECRET"
]
},
"secret": {
"lc": 1,
"type": "secret",
@@ -50,8 +43,15 @@
"SECRET"
]
},
"another_visible": "bye",
"you_can_see_me": "hello"
"you_can_see_me": "hello",
"another_secret": {
"lc": 1,
"type": "secret",
"id": [
"ANOTHER_SECRET"
]
},
"another_visible": "bye"
}
}
'''
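
Note that the kwargs in this snapshot are now emitted in the model's field-definition order (iteration over the model itself) rather than in the order the kwargs were passed to the constructor, which is why "secret" now precedes "another_secret".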
@@ -71,6 +71,61 @@
"LLMChain"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"name": "PromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
@@ -81,15 +136,23 @@
"OpenAI"
],
"kwargs": {
"model": "davinci",
"model_name": "davinci",
"temperature": 0.5,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {
@@ -130,30 +193,24 @@
]
}
},
"prompt": {
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string",
"partial_variables": {}
},
"name": "PromptTemplate",
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
"data": "StrOutputParserInput"
},
{
"id": 1,
@@ -161,17 +218,17 @@
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
"schema",
"output_parser",
"StrOutputParser"
],
"name": "PromptTemplate"
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
"data": "StrOutputParserOutput"
}
],
"edges": [
@@ -185,7 +242,8 @@
}
]
}
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
@@ -240,65 +298,6 @@
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
},
"name": "ChatOpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChatOpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"name": "ChatOpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "ChatOpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"prompt": {
"lc": 1,
"type": "constructor",
@@ -337,8 +336,7 @@
"name"
],
"template": "hello {name}!",
"template_format": "f-string",
"partial_variables": {}
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
@@ -381,8 +379,7 @@
}
}
}
],
"partial_variables": {}
]
},
"name": "ChatPromptTemplate",
"graph": {
@@ -422,8 +419,121 @@
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model_name": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"openai_proxy": "",
"max_retries": 2,
"n": 1
},
"name": "ChatOpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChatOpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"name": "ChatOpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "ChatOpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "StrOutputParserInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "StrOutputParserOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
"nodes": [
@@ -477,6 +587,61 @@
"LLMChain"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"name": "PromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
@@ -487,15 +652,23 @@
"OpenAI"
],
"kwargs": {
"model": "davinci",
"model_name": "davinci",
"temperature": 0.5,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {
@@ -536,30 +709,24 @@
]
}
},
"prompt": {
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string",
"partial_variables": {}
},
"name": "PromptTemplate",
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
"data": "StrOutputParserInput"
},
{
"id": 1,
@@ -567,17 +734,17 @@
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
"schema",
"output_parser",
"StrOutputParser"
],
"name": "PromptTemplate"
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
"data": "StrOutputParserOutput"
}
],
"edges": [
@@ -591,7 +758,8 @@
}
]
}
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
@@ -646,15 +814,23 @@
"OpenAI"
],
"kwargs": {
"model": "davinci",
"model_name": "davinci",
"temperature": 0.7,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {

View File

@@ -242,11 +242,6 @@ def test_aliases_hidden() -> None:
"type": "secret",
"id": ["MY_FAVORITE_SECRET"],
},
"my_favorite_secret_alias": {
"lc": 1,
"type": "secret",
"id": ["MY_FAVORITE_SECRET"],
},
"my_other_secret": {"lc": 1, "type": "secret", "id": ["MY_OTHER_SECRET"]},
},
}

View File

@@ -17,7 +17,9 @@ class NotSerializable:
def test_loads_openai_llm() -> None:
from langchain_openai import OpenAI
llm = CommunityOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
llm = CommunityOpenAI(
model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8
) # type: ignore[call-arg]
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
@@ -31,7 +33,9 @@ def test_loads_openai_llm() -> None:
def test_loads_llmchain() -> None:
from langchain_openai import OpenAI
llm = CommunityOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
llm = CommunityOpenAI(
model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8
) # type: ignore[call-arg]
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain)
@@ -54,7 +58,7 @@ def test_loads_llmchain_env() -> None:
if not has_env:
os.environ["OPENAI_API_KEY"] = "env_variable"
llm = OpenAI(model="davinci", temperature=0.5) # type: ignore[call-arg]
llm = OpenAI(model="davinci", temperature=0.5, top_p=0.8) # type: ignore[call-arg]
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain)
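
These tests now construct the LLM with a non-default top_p=0.8, exercising serialization of explicitly set fields under the new scheme. A round-trip sketch under the same assumptions as the tests above (langchain_openai installed, dummy key supplied and mapped back via the secrets map):

llm = OpenAI(model="davinci", temperature=0.5, top_p=0.8, openai_api_key="hello")  # type: ignore[call-arg]
llm2 = loads(dumps(llm), secrets_map={"OPENAI_API_KEY": "hello"})
assert dumps(llm2) == dumps(llm)  # the non-default top_p survives the round trip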