mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
core: Removing unnecessary pydantic
core schema rebuilds (#30848)
We only need to rebuild model schemas if type annotation information isn't available during declaration - that shouldn't be the case for these types corrected here. Need to do more thorough testing to make sure these structures have complete schemas, but hopefully this boosts startup / import time.
This commit is contained in:
parent
60d8ade078
commit
88fce67724
@ -276,9 +276,6 @@ class AIMessage(BaseMessage):
|
|||||||
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
return (base.strip() + "\n" + "\n".join(lines)).strip()
|
||||||
|
|
||||||
|
|
||||||
AIMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||||
"""Message chunk from an AI."""
|
"""Message chunk from an AI."""
|
||||||
|
|
||||||
|
@ -22,9 +22,6 @@ class ChatMessage(BaseMessage):
|
|||||||
"""The type of the message (used during serialization). Defaults to "chat"."""
|
"""The type of the message (used during serialization). Defaults to "chat"."""
|
||||||
|
|
||||||
|
|
||||||
ChatMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||||
"""Chat Message chunk."""
|
"""Chat Message chunk."""
|
||||||
|
|
||||||
|
@ -30,9 +30,6 @@ class FunctionMessage(BaseMessage):
|
|||||||
"""The type of the message (used for serialization). Defaults to "function"."""
|
"""The type of the message (used for serialization). Defaults to "function"."""
|
||||||
|
|
||||||
|
|
||||||
FunctionMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||||
"""Function Message chunk."""
|
"""Function Message chunk."""
|
||||||
|
|
||||||
|
@ -52,9 +52,6 @@ class HumanMessage(BaseMessage):
|
|||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
HumanMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||||
"""Human Message chunk."""
|
"""Human Message chunk."""
|
||||||
|
|
||||||
|
@ -26,6 +26,3 @@ class RemoveMessage(BaseMessage):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
super().__init__("", id=id, **kwargs)
|
super().__init__("", id=id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
RemoveMessage.model_rebuild()
|
|
||||||
|
@ -46,9 +46,6 @@ class SystemMessage(BaseMessage):
|
|||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
SystemMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||||
"""System Message chunk."""
|
"""System Message chunk."""
|
||||||
|
|
||||||
|
@ -146,9 +146,6 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
|
|||||||
super().__init__(content=content, **kwargs)
|
super().__init__(content=content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
ToolMessage.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
|
||||||
"""Tool Message chunk."""
|
"""Tool Message chunk."""
|
||||||
|
|
||||||
|
@ -133,9 +133,6 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
|
|||||||
yield [part]
|
yield [part]
|
||||||
|
|
||||||
|
|
||||||
ListOutputParser.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class CommaSeparatedListOutputParser(ListOutputParser):
|
class CommaSeparatedListOutputParser(ListOutputParser):
|
||||||
"""Parse the output of an LLM call to a comma-separated list."""
|
"""Parse the output of an LLM call to a comma-separated list."""
|
||||||
|
|
||||||
|
@ -114,9 +114,6 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
|
|||||||
return self.pydantic_object
|
return self.pydantic_object
|
||||||
|
|
||||||
|
|
||||||
PydanticOutputParser.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||||
|
|
||||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||||
|
@ -31,6 +31,3 @@ class StrOutputParser(BaseTransformOutputParser[str]):
|
|||||||
def parse(self, text: str) -> str:
|
def parse(self, text: str) -> str:
|
||||||
"""Returns the input text with no changes."""
|
"""Returns the input text with no changes."""
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
StrOutputParser.model_rebuild()
|
|
||||||
|
@ -132,6 +132,3 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
|||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
PipelinePromptTemplate.model_rebuild()
|
|
||||||
|
@ -5650,9 +5650,6 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
RunnableBindingBase.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||||
"""Wrap a Runnable with additional functionality.
|
"""Wrap a Runnable with additional functionality.
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@ from abc import abstractmethod
|
|||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Iterator,
|
Iterator,
|
||||||
Mapping, # noqa: F401 Needed by pydantic
|
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@ -464,9 +463,6 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|||||||
return (self.default, config)
|
return (self.default, config)
|
||||||
|
|
||||||
|
|
||||||
RunnableConfigurableFields.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
# Before Python 3.11 native StrEnum is not available
|
# Before Python 3.11 native StrEnum is not available
|
||||||
class StrEnum(str, enum.Enum):
|
class StrEnum(str, enum.Enum):
|
||||||
"""String enum."""
|
"""String enum."""
|
||||||
|
@ -6,6 +6,7 @@ import ast
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from contextvars import Context
|
from contextvars import Context
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
@ -33,8 +34,6 @@ if TYPE_CHECKING:
|
|||||||
Awaitable,
|
Awaitable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Iterable,
|
Iterable,
|
||||||
Mapping,
|
|
||||||
Sequence,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
|
@ -176,6 +176,3 @@ class Tool(BaseTool):
|
|||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Tool.model_rebuild()
|
|
||||||
|
@ -227,9 +227,6 @@ class SerializableModel(GenericFakeChatModel):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
SerializableModel.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialization_with_rate_limiter() -> None:
|
def test_serialization_with_rate_limiter() -> None:
|
||||||
"""Test model serialization with rate limiter."""
|
"""Test model serialization with rate limiter."""
|
||||||
from langchain_core.load import dumps
|
from langchain_core.load import dumps
|
||||||
|
@ -45,8 +45,6 @@ def test_base_generation_parser() -> None:
|
|||||||
assert isinstance(content, str)
|
assert isinstance(content, str)
|
||||||
return content.swapcase()
|
return content.swapcase()
|
||||||
|
|
||||||
StrInvertCase.model_rebuild()
|
|
||||||
|
|
||||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
|
model = GenericFakeChatModel(messages=iter([AIMessage(content="hEllo")]))
|
||||||
chain = model | StrInvertCase()
|
chain = model | StrInvertCase()
|
||||||
assert chain.invoke("") == "HeLLO"
|
assert chain.invoke("") == "HeLLO"
|
||||||
|
@ -35,9 +35,6 @@ class FakeStructuredChatModel(FakeListChatModel):
|
|||||||
return "fake-messages-list-chat-model"
|
return "fake-messages-list-chat-model"
|
||||||
|
|
||||||
|
|
||||||
FakeStructuredChatModel.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
def test_structured_prompt_pydantic() -> None:
|
def test_structured_prompt_pydantic() -> None:
|
||||||
class OutputSchema(BaseModel):
|
class OutputSchema(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
@ -1188,9 +1188,6 @@ class HardCodedRetriever(BaseRetriever):
|
|||||||
return self.documents
|
return self.documents
|
||||||
|
|
||||||
|
|
||||||
HardCodedRetriever.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_event_stream_with_retriever() -> None:
|
async def test_event_stream_with_retriever() -> None:
|
||||||
"""Test the event stream with a retriever."""
|
"""Test the event stream with a retriever."""
|
||||||
retriever = HardCodedRetriever(
|
retriever = HardCodedRetriever(
|
||||||
|
20
libs/core/tests/unit_tests/test_pydantic_imports.py
Normal file
20
libs/core/tests/unit_tests/test_pydantic_imports.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_models_built() -> None:
|
||||||
|
for path in Path("../core/langchain_core/").glob("*"):
|
||||||
|
module_name = path.stem
|
||||||
|
if not module_name.startswith(".") and path.suffix != ".typed":
|
||||||
|
module = importlib.import_module("langchain_core." + module_name)
|
||||||
|
all_ = getattr(module, "__all__", [])
|
||||||
|
for attr_name in all_:
|
||||||
|
attr = getattr(module, attr_name)
|
||||||
|
try:
|
||||||
|
if issubclass(attr, BaseModel):
|
||||||
|
assert attr.__pydantic_complete__ is True
|
||||||
|
except TypeError:
|
||||||
|
# This is expected for non-class attributes
|
||||||
|
pass
|
@ -1091,9 +1091,6 @@ class FooBase(BaseTool):
|
|||||||
return assert_bar(bar, bar_config)
|
return assert_bar(bar, bar_config)
|
||||||
|
|
||||||
|
|
||||||
FooBase.model_rebuild()
|
|
||||||
|
|
||||||
|
|
||||||
class AFooBase(FooBase):
|
class AFooBase(FooBase):
|
||||||
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
async def _arun(self, bar: Any, bar_config: RunnableConfig, **kwargs: Any) -> Any:
|
||||||
return assert_bar(bar, bar_config)
|
return assert_bar(bar, bar_config)
|
||||||
|
9
uv.lock
9
uv.lock
@ -2338,7 +2338,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "chromadb", specifier = ">=0.4.0,!=0.5.4,!=0.5.5,!=0.5.7,!=0.5.9,!=0.5.10,!=0.5.11,!=0.5.12,<0.7.0" },
|
{ name = "chromadb", specifier = ">=0.4.0,!=0.5.4,!=0.5.5,!=0.5.7,!=0.5.9,!=0.5.10,!=0.5.11,!=0.5.12,<0.7.0" },
|
||||||
{ name = "langchain-core", editable = "libs/core" },
|
{ name = "langchain-core", editable = "libs/core" },
|
||||||
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.22.4" },
|
{ name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.0" },
|
||||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -2451,7 +2451,7 @@ typing = [
|
|||||||
{ name = "langchain", editable = "libs/langchain" },
|
{ name = "langchain", editable = "libs/langchain" },
|
||||||
{ name = "langchain-core", editable = "libs/core" },
|
{ name = "langchain-core", editable = "libs/core" },
|
||||||
{ name = "langchain-text-splitters", editable = "libs/text-splitters" },
|
{ name = "langchain-text-splitters", editable = "libs/text-splitters" },
|
||||||
{ name = "mypy", specifier = ">=1.12,<2.0" },
|
{ name = "mypy", specifier = ">=1.15,<2.0" },
|
||||||
{ name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" },
|
{ name = "mypy-protobuf", specifier = ">=3.0.0,<4.0.0" },
|
||||||
{ name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" },
|
{ name = "types-chardet", specifier = ">=5.0.4.6,<6.0.0.0" },
|
||||||
{ name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" },
|
{ name = "types-pytz", specifier = ">=2023.3.0.0,<2024.0.0.0" },
|
||||||
@ -2503,6 +2503,8 @@ test = [
|
|||||||
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.1.0" },
|
||||||
{ name = "pytest", specifier = ">=8,<9" },
|
{ name = "pytest", specifier = ">=8,<9" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
|
{ name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" },
|
||||||
|
{ name = "pytest-benchmark" },
|
||||||
|
{ name = "pytest-codspeed" },
|
||||||
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
|
{ name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" },
|
||||||
{ name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
|
{ name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" },
|
||||||
{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
|
{ name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" },
|
||||||
@ -2513,8 +2515,7 @@ test = [
|
|||||||
test-integration = []
|
test-integration = []
|
||||||
typing = [
|
typing = [
|
||||||
{ name = "langchain-text-splitters", directory = "libs/text-splitters" },
|
{ name = "langchain-text-splitters", directory = "libs/text-splitters" },
|
||||||
{ name = "mypy", specifier = ">=1.10,<1.11" },
|
{ name = "mypy", specifier = ">=1.15,<1.16" },
|
||||||
{ name = "types-jinja2", specifier = ">=2.11.9,<3.0.0" },
|
|
||||||
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
|
{ name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" },
|
||||||
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
|
{ name = "types-requests", specifier = ">=2.28.11.5,<3.0.0.0" },
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user