core: Removing unnecessary pydantic core schema rebuilds (#30848)

We only need to rebuild model schemas if type annotation information
isn't available during declaration - that shouldn't be the case for
these types corrected here.

Need to do more thorough testing to make sure these structures have
complete schemas, but hopefully this boosts startup / import time.
This commit is contained in:
Sydney Runkle 2025-04-16 12:00:08 -04:00 committed by GitHub
parent 60d8ade078
commit 88fce67724
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 26 additions and 63 deletions

View File

@ -276,9 +276,6 @@ class AIMessage(BaseMessage):
return (base.strip() + "\n" + "\n".join(lines)).strip() return (base.strip() + "\n" + "\n".join(lines)).strip()
AIMessage.model_rebuild()
class AIMessageChunk(AIMessage, BaseMessageChunk): class AIMessageChunk(AIMessage, BaseMessageChunk):
"""Message chunk from an AI.""" """Message chunk from an AI."""

View File

@ -22,9 +22,6 @@ class ChatMessage(BaseMessage):
"""The type of the message (used during serialization). Defaults to "chat".""" """The type of the message (used during serialization). Defaults to "chat"."""
ChatMessage.model_rebuild()
class ChatMessageChunk(ChatMessage, BaseMessageChunk): class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""Chat Message chunk.""" """Chat Message chunk."""

View File

@ -30,9 +30,6 @@ class FunctionMessage(BaseMessage):
"""The type of the message (used for serialization). Defaults to "function".""" """The type of the message (used for serialization). Defaults to "function"."""
FunctionMessage.model_rebuild()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""Function Message chunk.""" """Function Message chunk."""

View File

@ -52,9 +52,6 @@ class HumanMessage(BaseMessage):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
HumanMessage.model_rebuild()
class HumanMessageChunk(HumanMessage, BaseMessageChunk): class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""Human Message chunk.""" """Human Message chunk."""

View File

@ -26,6 +26,3 @@ class RemoveMessage(BaseMessage):
raise ValueError(msg) raise ValueError(msg)
super().__init__("", id=id, **kwargs) super().__init__("", id=id, **kwargs)
RemoveMessage.model_rebuild()

View File

@ -46,9 +46,6 @@ class SystemMessage(BaseMessage):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
SystemMessage.model_rebuild()
class SystemMessageChunk(SystemMessage, BaseMessageChunk): class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""System Message chunk.""" """System Message chunk."""

View File

@ -146,9 +146,6 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
super().__init__(content=content, **kwargs) super().__init__(content=content, **kwargs)
ToolMessage.model_rebuild()
class ToolMessageChunk(ToolMessage, BaseMessageChunk): class ToolMessageChunk(ToolMessage, BaseMessageChunk):
"""Tool Message chunk.""" """Tool Message chunk."""

View File

@ -133,9 +133,6 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
yield [part] yield [part]
ListOutputParser.model_rebuild()
class CommaSeparatedListOutputParser(ListOutputParser): class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse the output of an LLM call to a comma-separated list.""" """Parse the output of an LLM call to a comma-separated list."""

View File

@ -114,9 +114,6 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
return self.pydantic_object return self.pydantic_object
PydanticOutputParser.model_rebuild()
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below. _PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}} As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}

View File

@ -31,6 +31,3 @@ class StrOutputParser(BaseTransformOutputParser[str]):
def parse(self, text: str) -> str: def parse(self, text: str) -> str:
"""Returns the input text with no changes.""" """Returns the input text with no changes."""
return text return text
StrOutputParser.model_rebuild()

View File

@ -132,6 +132,3 @@ class PipelinePromptTemplate(BasePromptTemplate):
@property @property
def _prompt_type(self) -> str: def _prompt_type(self) -> str:
raise ValueError raise ValueError
PipelinePromptTemplate.model_rebuild()

View File

@ -5650,9 +5650,6 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item yield item
RunnableBindingBase.model_rebuild()
class RunnableBinding(RunnableBindingBase[Input, Output]): class RunnableBinding(RunnableBindingBase[Input, Output]):
"""Wrap a Runnable with additional functionality. """Wrap a Runnable with additional functionality.

View File

@ -8,7 +8,6 @@ from abc import abstractmethod
from collections.abc import ( from collections.abc import (
AsyncIterator, AsyncIterator,
Iterator, Iterator,
Mapping, # noqa: F401 Needed by pydantic
Sequence, Sequence,
) )
from functools import wraps from functools import wraps
@ -464,9 +463,6 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
return (self.default, config) return (self.default, config)
RunnableConfigurableFields.model_rebuild()
# Before Python 3.11 native StrEnum is not available # Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum): class StrEnum(str, enum.Enum):
"""String enum.""" """String enum."""

View File

@ -6,6 +6,7 @@ import ast
import asyncio import asyncio
import inspect import inspect
import textwrap import textwrap
from collections.abc import Mapping, Sequence
from contextvars import Context from contextvars import Context
from functools import lru_cache from functools import lru_cache
from inspect import signature from inspect import signature
@ -33,8 +34,6 @@ if TYPE_CHECKING:
Awaitable, Awaitable,
Coroutine, Coroutine,
Iterable, Iterable,
Mapping,
Sequence,
) )
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent

View File

@ -176,6 +176,3 @@ class Tool(BaseTool):
args_schema=args_schema, args_schema=args_schema,
**kwargs, **kwargs,
) )
Tool.model_rebuild()

View File

@ -227,9 +227,6 @@ class SerializableModel(GenericFakeChatModel):
return True return True
SerializableModel.model_rebuild()
def test_serialization_with_rate_limiter() -> None: def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter.""" """Test model serialization with rate limiter."""
from langchain_core.load import dumps from langchain_core.load import dumps

View File

@ -45,8 +45,6 @@ def test_base_generation_parser() -> None:
assert isinstance(content, str) assert isinstance(content, str)
return content.swapcase() return content.swapcase()
StrInvertCase.model_rebuild()
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")])) model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
chain = model | StrInvertCase() chain = model | StrInvertCase()
assert chain.invoke("") == "HeLLO" assert chain.invoke("") == "HeLLO"

View File

@ -35,9 +35,6 @@ class FakeStructuredChatModel(FakeListChatModel):
return "fake-messages-list-chat-model" return "fake-messages-list-chat-model"
FakeStructuredChatModel.model_rebuild()
def test_structured_prompt_pydantic() -> None: def test_structured_prompt_pydantic() -> None:
class OutputSchema(BaseModel): class OutputSchema(BaseModel):
name: str name: str

View File

@ -1188,9 +1188,6 @@ class HardCodedRetriever(BaseRetriever):
return self.documents return self.documents
HardCodedRetriever.model_rebuild()
async def test_event_stream_with_retriever() -> None: async def test_event_stream_with_retriever() -> None:
"""Test the event stream with a retriever.""" """Test the event stream with a retriever."""
retriever = HardCodedRetriever( retriever = HardCodedRetriever(

View File

@ -0,0 +1,20 @@
import importlib
from pathlib import Path
from pydantic import BaseModel
def test_all_models_built() -> None:
for path in Path("../core/langchain_core/").glob("*"):
module_name = path.stem
if not module_name.startswith(".") and path.suffix != ".typed":
module = importlib.import_module("langchain_core." + module_name)
all_ = getattr(module, "__all__", [])
for attr_name in all_:
attr = getattr(module, attr_name)
try:
if issubclass(attr, BaseModel):
assert attr.__pydantic_complete__ is True
except TypeError:
# This is expected for non-class attributes
pass

View File

@ -1091,9 +1091,6 @@ class FooBase(BaseTool):
return assert_bar(bar, bar_config) return assert_bar(bar, bar_config)
FooBase.model_rebuild()
class AFooBase(FooBase): class AFooBase(FooBase):
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any: async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
return assert_bar(bar, bar_config) return assert_bar(bar, bar_config)

View File

@ -2338,7 +2338,7 @@ dependencies = [
requires-dist = [ requires-dist = [
{ name = "chromadb", specifier = ">=0.4.0,!=0.5.4,!=0.5.5,!=0.5.7,!=0.5.9,!=0.5.10,!=0.5.11,!=0.5.12,<0.7.0" }, { name = "chromadb", specifier = ">=0.4.0,!=0.5.4,!=0.5.5,!=0.5.7,!=0.5.9,!=0.5.10,!=0.5.11,!=0.5.12,<0.7.0" },
{ name = "langchain-core", editable = "libs/core" }, { name = "langchain-core", editable = "libs/core" },
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.22.4" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.0" },
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" }, { name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
] ]
@ -2451,7 +2451,7 @@ typing = [
{ name = "langchain", editable = "libs/langchain" }, { name = "langchain", editable = "libs/langchain" },
{ name = "langchain-core", editable = "libs/core" }, { name = "langchain-core", editable = "libs/core" },
{ name = "langchain-text-splitters", editable = "libs/text-splitters" }, { name = "langchain-text-splitters", editable = "libs/text-splitters" },
{ name = "mypy", specifier = ">=1.12,<2.0" }, { name = "mypy", specifier = ">=1.15,<2.0" },
{ name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" }, { name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" },
{ name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" }, { name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" },
{ name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" }, { name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" },
@ -2503,6 +2503,8 @@ test = [
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" }, { name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
{ name = "pytest", specifier = ">=8,<9" }, { name = "pytest", specifier = ">=8,<9" },
{ name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
{ name = "pytest-benchmark" },
{ name = "pytest-codspeed" },
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" }, { name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
{ name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" }, { name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" }, { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
@ -2513,8 +2515,7 @@ test = [
test-integration = [] test-integration = []
typing = [ typing = [
{ name = "langchain-text-splitters", directory = "libs/text-splitters" }, { name = "langchain-text-splitters", directory = "libs/text-splitters" },
{ name = "mypy", specifier = ">=1.10,<1.11" }, { name = "mypy", specifier = ">=1.15,<1.16" },
{ name = "types-jinja2", specifier = ">=2.11.9,<3.0.0" },
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" }, { name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" }, { name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
] ]