core[patch]: Fix changes to pydantic schema due to pydantic 2.8.2 -> 2.9 changes (#26166)

Minor non-functional change in pydantic schema generation.
This commit is contained in:
Eugene Yurtsev
2024-09-06 17:24:10 -04:00
committed by GitHub
parent b2ba4f4072
commit 61087b0c0d
5 changed files with 625 additions and 433 deletions

View File

@@ -0,0 +1,180 @@
# serializer version: 1
# name: test_mustache_prompt_from_template[schema_0]
dict({
'$defs': dict({
'obj': dict({
'properties': dict({
'bar': dict({
'title': 'Bar',
'type': 'string',
}),
'foo': dict({
'title': 'Foo',
'type': 'string',
}),
}),
'title': 'obj',
'type': 'object',
}),
}),
'properties': dict({
'foo': dict({
'title': 'Foo',
'type': 'string',
}),
'obj': dict({
'$ref': '#/$defs/obj',
}),
}),
'title': 'PromptInput',
'type': 'object',
})
# ---
# name: test_mustache_prompt_from_template[schema_2]
dict({
'$defs': dict({
'foo': dict({
'properties': dict({
'bar': dict({
'title': 'Bar',
'type': 'string',
}),
}),
'title': 'foo',
'type': 'object',
}),
}),
'properties': dict({
'foo': dict({
'$ref': '#/$defs/foo',
}),
}),
'title': 'PromptInput',
'type': 'object',
})
# ---
# name: test_mustache_prompt_from_template[schema_3]
dict({
'$defs': dict({
'baz': dict({
'properties': dict({
'qux': dict({
'title': 'Qux',
'type': 'string',
}),
}),
'title': 'baz',
'type': 'object',
}),
'foo': dict({
'properties': dict({
'bar': dict({
'title': 'Bar',
'type': 'string',
}),
'baz': dict({
'$ref': '#/$defs/baz',
}),
'quux': dict({
'title': 'Quux',
'type': 'string',
}),
}),
'title': 'foo',
'type': 'object',
}),
}),
'properties': dict({
'foo': dict({
'$ref': '#/$defs/foo',
}),
}),
'title': 'PromptInput',
'type': 'object',
})
# ---
# name: test_mustache_prompt_from_template[schema_4]
dict({
'$defs': dict({
'barfoo': dict({
'properties': dict({
'foobar': dict({
'title': 'Foobar',
'type': 'string',
}),
}),
'title': 'barfoo',
'type': 'object',
}),
'baz': dict({
'properties': dict({
'qux': dict({
'$ref': '#/$defs/qux',
}),
}),
'title': 'baz',
'type': 'object',
}),
'foo': dict({
'properties': dict({
'bar': dict({
'title': 'Bar',
'type': 'string',
}),
'baz': dict({
'$ref': '#/$defs/baz',
}),
'quux': dict({
'title': 'Quux',
'type': 'string',
}),
}),
'title': 'foo',
'type': 'object',
}),
'qux': dict({
'properties': dict({
'barfoo': dict({
'$ref': '#/$defs/barfoo',
}),
'foobar': dict({
'title': 'Foobar',
'type': 'string',
}),
}),
'title': 'qux',
'type': 'object',
}),
}),
'properties': dict({
'foo': dict({
'$ref': '#/$defs/foo',
}),
}),
'title': 'PromptInput',
'type': 'object',
})
# ---
# name: test_mustache_prompt_from_template[schema_5]
dict({
'$defs': dict({
'foo': dict({
'properties': dict({
'bar': dict({
'title': 'Bar',
'type': 'string',
}),
}),
'title': 'foo',
'type': 'object',
}),
}),
'properties': dict({
'foo': dict({
'$ref': '#/$defs/foo',
}),
}),
'title': 'PromptInput',
'type': 'object',
})
# ---

View File

@@ -3,10 +3,15 @@
from typing import Any, Dict, Union
from unittest import mock
import pydantic
import pytest
from syrupy import SnapshotAssertion
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from tests.unit_tests.pydantic_utils import _normalize_schema
# Numeric components of the installed pydantic version, e.g. (2, 9, 0).
# Non-numeric components (such as the "0b1" in a "2.10.0b1" pre-release) are
# skipped so that importing this module does not raise ValueError on them.
PYDANTIC_VERSION = tuple(
    int(part) for part in pydantic.__version__.split(".") if part.isdigit()
)
def test_prompt_valid() -> None:
@@ -62,7 +67,7 @@ def test_prompt_from_template() -> None:
assert prompt == expected_prompt
def test_mustache_prompt_from_template() -> None:
def test_mustache_prompt_from_template(snapshot: SnapshotAssertion) -> None:
"""Test prompts can be constructed from a template."""
# Single input variable.
template = "This is a {{foo}} test."
@@ -110,24 +115,10 @@ def test_mustache_prompt_from_template() -> None:
"This foo is a bar test baz."
)
assert prompt.input_variables == ["foo", "obj"]
assert prompt.get_input_jsonschema() == {
"$defs": {
"obj": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"foo": {"default": None, "title": "Foo", "type": "string"},
},
"title": "obj",
"type": "object",
}
},
"properties": {
"foo": {"default": None, "title": "Foo", "type": "string"},
"obj": {"allOf": [{"$ref": "#/$defs/obj"}], "default": None},
},
"title": "PromptInput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_0"
)
# . variables
template = "This {{.}} is a test."
@@ -151,20 +142,10 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.get_input_jsonschema() == {
"$defs": {
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"}
},
"title": "foo",
"type": "object",
}
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_2"
)
# more complex nested section/context variables
template = """This{{#foo}}
@@ -185,29 +166,10 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.get_input_jsonschema() == {
"$defs": {
"baz": {
"properties": {
"qux": {"default": None, "title": "Qux", "type": "string"}
},
"title": "baz",
"type": "object",
},
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
"quux": {"default": None, "title": "Quux", "type": "string"},
},
"title": "foo",
"type": "object",
},
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_3"
)
# triply nested section/context variables
template = """This{{#foo}}
@@ -242,44 +204,10 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.get_input_jsonschema() == {
"$defs": {
"barfoo": {
"properties": {
"foobar": {"default": None, "title": "Foobar", "type": "string"}
},
"title": "barfoo",
"type": "object",
},
"baz": {
"properties": {
"qux": {"allOf": [{"$ref": "#/$defs/qux"}], "default": None}
},
"title": "baz",
"type": "object",
},
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"},
"baz": {"allOf": [{"$ref": "#/$defs/baz"}], "default": None},
"quux": {"default": None, "title": "Quux", "type": "string"},
},
"title": "foo",
"type": "object",
},
"qux": {
"properties": {
"barfoo": {"allOf": [{"$ref": "#/$defs/barfoo"}], "default": None},
"foobar": {"default": None, "title": "Foobar", "type": "string"},
},
"title": "qux",
"type": "object",
},
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_4"
)
# section/context variables with repeats
template = """This{{#foo}}
@@ -294,20 +222,10 @@ def test_mustache_prompt_from_template() -> None:
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.get_input_jsonschema() == {
"$defs": {
"foo": {
"properties": {
"bar": {"default": None, "title": "Bar", "type": "string"}
},
"title": "foo",
"type": "object",
}
},
"properties": {"foo": {"allOf": [{"$ref": "#/$defs/foo"}], "default": None}},
"title": "PromptInput",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(prompt.get_input_jsonschema()) == snapshot(
name="schema_5"
)
template = """This{{^foo}}
no foos
{{/foo}}is a test."""

View File

@@ -1,4 +1,6 @@
from typing import Any
from typing import Any, Dict
from pydantic import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -92,3 +94,29 @@ def _schema(obj: Any) -> dict:
_remove_enum_description(schema_)
return schema_
def _normalize_schema(obj: Any) -> Dict[str, Any]:
    """Generate a schema and normalize it.

    This will collapse single element allOfs into $ref.

    For example,

        'obj': {'allOf': [{'$ref': '#/$defs/obj'}]}

    to:

        'obj': {'$ref': '#/$defs/obj'}

    Args:
        obj: The object to generate the schema for. A pydantic model instance
            or model class is converted with ``model_json_schema()``; any
            other value is treated as an already-generated schema and is
            used as-is.

    Returns:
        The normalized schema. When ``obj`` was not a pydantic model, this is
        the very same object that was passed in (it is not copied).
    """
    is_model_class = isinstance(obj, type) and issubclass(obj, BaseModel)
    if is_model_class or isinstance(obj, BaseModel):
        data = obj.model_json_schema()
    else:
        data = obj
    # The helpers' return values are ignored, so they are relied on to edit
    # `data` in place.
    # NOTE(review): judging by its name, the first helper presumably drops
    # `"default": None` entries -- confirm against its definition.
    remove_all_none_default(data)
    replace_all_of_with_ref(data)
    _remove_enum_description(data)
    return data

File diff suppressed because one or more lines are too long

View File

@@ -18,6 +18,7 @@ from typing import (
)
from uuid import UUID
import pydantic
import pytest
from freezegun import freeze_time
from pydantic import BaseModel, Field
@@ -90,9 +91,11 @@ from langchain_core.tracers import (
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.pydantic_utils import _schema
from tests.unit_tests.pydantic_utils import _normalize_schema, _schema
from tests.unit_tests.stubs import AnyStr, _AnyIdAIMessage, _AnyIdAIMessageChunk
# Numeric components of the installed pydantic version, e.g. (2, 9, 0).
# Non-numeric components (such as the "0b1" in a "2.10.0b1" pre-release) are
# skipped so that importing this module does not raise ValueError on them.
PYDANTIC_VERSION = tuple(
    int(part) for part in pydantic.__version__.split(".") if part.isdigit()
)
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution.
@@ -543,7 +546,7 @@ def test_passthrough_assign_schema() -> None:
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_lambda_schemas() -> None:
def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
first_lambda = lambda x: x["hello"] # noqa: E731
assert RunnableLambda(first_lambda).get_input_jsonschema() == {
"title": "RunnableLambdaInput",
@@ -617,45 +620,37 @@ def test_lambda_schemas() -> None:
}
assert (
RunnableLambda(
aget_values_typed # type: ignore[arg-type]
).get_input_jsonschema()
== {
"$defs": {
"InputType": {
"properties": {
"variable_name": {
"title": "Variable " "Name",
"type": "string",
_normalize_schema(
RunnableLambda(
aget_values_typed # type: ignore[arg-type]
).get_input_jsonschema()
)
== _normalize_schema(
{
"$defs": {
"InputType": {
"properties": {
"variable_name": {
"title": "Variable " "Name",
"type": "string",
},
"yo": {"title": "Yo", "type": "integer"},
},
"yo": {"title": "Yo", "type": "integer"},
},
"required": ["variable_name", "yo"],
"title": "InputType",
"type": "object",
}
},
"allOf": [{"$ref": "#/$defs/InputType"}],
"title": "aget_values_typed_input",
}
"required": ["variable_name", "yo"],
"title": "InputType",
"type": "object",
}
},
"allOf": [{"$ref": "#/$defs/InputType"}],
"title": "aget_values_typed_input",
}
)
)
assert RunnableLambda(aget_values_typed).get_output_jsonschema() == { # type: ignore[arg-type]
"$defs": {
"OutputType": {
"properties": {
"bye": {"title": "Bye", "type": "string"},
"byebye": {"title": "Byebye", "type": "integer"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["hello", "bye", "byebye"],
"title": "OutputType",
"type": "object",
}
},
"allOf": [{"$ref": "#/$defs/OutputType"}],
"title": "aget_values_typed_output",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
RunnableLambda(aget_values_typed).get_output_jsonschema() # type: ignore
) == snapshot(name="schema8")
def test_with_types_with_type_generics() -> None:
@@ -752,7 +747,7 @@ def test_schema_complex_seq() -> None:
}
def test_configurable_fields() -> None:
def test_configurable_fields(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert fake_llm.invoke("...") == "a"
@@ -767,38 +762,10 @@ def test_configurable_fields() -> None:
assert fake_llm_configurable.invoke("...") == "a"
assert fake_llm_configurable.get_config_jsonschema() == {
"$defs": {
"Configurable": {
"properties": {
"llm_responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
}
},
"title": "Configurable",
"type": "object",
}
},
"properties": {
"configurable": {
"allOf": [{"$ref": "#/$defs/Configurable"}],
"default": None,
}
},
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
fake_llm_configurable.get_config_jsonschema()
) == snapshot(name="schema2")
fake_llm_configured = fake_llm_configurable.with_config(
configurable={"llm_responses": ["b"]}
@@ -822,35 +789,10 @@ def test_configurable_fields() -> None:
text="Hello, John!"
)
assert prompt_configurable.get_config_jsonschema() == {
"$defs": {
"Configurable": {
"properties": {
"prompt_template": {
"default": "Hello, " "{name}!",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
"type": "string",
}
},
"title": "Configurable",
"type": "object",
}
},
"properties": {
"configurable": {
"allOf": [{"$ref": "#/$defs/Configurable"}],
"default": None,
}
},
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
prompt_configurable.get_config_jsonschema()
) == snapshot(name="schema3")
prompt_configured = prompt_configurable.with_config(
configurable={"prompt_template": "Hello, {name}! {name}!"}
@@ -876,49 +818,10 @@ def test_configurable_fields() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.get_config_jsonschema() == {
"$defs": {
"Configurable": {
"properties": {
"llm_responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
},
"prompt_template": {
"default": "Hello, " "{name}!",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
"type": "string",
},
},
"title": "Configurable",
"type": "object",
}
},
"properties": {
"configurable": {
"allOf": [{"$ref": "#/$defs/Configurable"}],
"default": None,
}
},
"title": "RunnableSequenceConfig",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
chain_configurable.get_config_jsonschema()
) == snapshot(name="schema4")
assert (
chain_configurable.with_config(
@@ -960,55 +863,10 @@ def test_configurable_fields() -> None:
"llm3": "a",
}
assert chain_with_map_configurable.get_config_jsonschema() == {
"$defs": {
"Configurable": {
"properties": {
"llm_responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
},
"other_responses": {
"default": ["a"],
"items": {"type": "string"},
"title": "Other " "Responses",
"type": "array",
},
"prompt_template": {
"default": "Hello, " "{name}!",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
"type": "string",
},
},
"title": "Configurable",
"type": "object",
}
},
"properties": {
"configurable": {
"allOf": [{"$ref": "#/$defs/Configurable"}],
"default": None,
}
},
"title": "RunnableSequenceConfig",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
chain_with_map_configurable.get_config_jsonschema()
) == snapshot(name="schema5")
assert chain_with_map_configurable.with_config(
configurable={
@@ -1030,7 +888,7 @@ def test_configurable_alts_factory() -> None:
assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"
def test_configurable_fields_prefix_keys() -> None:
def test_configurable_fields_prefix_keys(snapshot: SnapshotAssertion) -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
id="responses",
@@ -1078,74 +936,13 @@ def test_configurable_fields_prefix_keys() -> None:
chain = prompt | fake_llm
assert _schema(chain.config_schema()) == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"Chat_Responses": {
"enum": ["hello", "bye", "helpful"],
"title": "Chat Responses",
"type": "string",
},
"Configurable": {
"properties": {
"chat_sleep": {
"anyOf": [{"type": "number"}, {"type": "null"}],
"default": None,
"title": "Chat " "Sleep",
},
"llm": {
"$ref": "#/definitions/LLM",
"default": "default",
"title": "LLM",
},
"llm==chat/responses": {
"default": ["hello", "bye"],
"items": {"$ref": "#/definitions/Chat_Responses"},
"title": "Chat " "Responses",
"type": "array",
},
"llm==default/responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
},
"prompt_template": {
"$ref": "#/definitions/Prompt_Template",
"default": "hello",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
},
},
"title": "Configurable",
"type": "object",
},
"LLM": {"enum": ["chat", "default"], "title": "LLM", "type": "string"},
"Prompt_Template": {
"enum": ["hello", "good_morning"],
"title": "Prompt Template",
"type": "string",
},
},
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(_schema(chain.config_schema())) == snapshot(
name="schema6"
)
def test_configurable_fields_example() -> None:
def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
id="chat_responses",
@@ -1191,71 +988,10 @@ def test_configurable_fields_example() -> None:
assert chain_configurable.invoke({"name": "John"}) == "a"
assert chain_configurable.get_config_jsonschema() == {
"$defs": {
"Chat_Responses": {
"enum": ["hello", "bye", "helpful"],
"title": "Chat Responses",
"type": "string",
},
"Configurable": {
"properties": {
"chat_responses": {
"default": ["hello", "bye"],
"items": {"$ref": "#/$defs/Chat_Responses"},
"title": "Chat " "Responses",
"type": "array",
},
"llm": {
"allOf": [{"$ref": "#/$defs/LLM"}],
"default": "default",
"title": "LLM",
},
"llm_responses": {
"default": ["a"],
"description": "A "
"list "
"of "
"fake "
"responses "
"for "
"this "
"LLM",
"items": {"type": "string"},
"title": "LLM " "Responses",
"type": "array",
},
"prompt_template": {
"allOf": [{"$ref": "#/$defs/Prompt_Template"}],
"default": "hello",
"description": "The "
"prompt "
"template "
"for "
"this "
"chain",
"title": "Prompt " "Template",
},
},
"title": "Configurable",
"type": "object",
},
"LLM": {"enum": ["chat", "default"], "title": "LLM", "type": "string"},
"Prompt_Template": {
"enum": ["hello", "good_morning"],
"title": "Prompt Template",
"type": "string",
},
},
"properties": {
"configurable": {
"allOf": [{"$ref": "#/$defs/Configurable"}],
"default": None,
}
},
"title": "RunnableSequenceConfig",
"type": "object",
}
if PYDANTIC_VERSION >= (2, 9):
assert _normalize_schema(
chain_configurable.get_config_jsonschema()
) == snapshot(name="schema7")
assert (
chain_configurable.with_config(configurable={"llm": "chat"}).invoke(