rework import ordering

This commit is contained in:
Mason Daugherty 2025-08-06 17:24:00 -04:00
parent ad8db57efa
commit 0e6f3ecad0
No known key found for this signature in database
2 changed files with 5 additions and 7 deletions

View File

@ -1,11 +1,14 @@
"""Dump objects to json.""" """Dump objects to json."""
import dataclasses
import inspect
import json import json
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from langchain_core.load.serializable import Serializable, to_json_not_implemented from langchain_core.load.serializable import Serializable, to_json_not_implemented
from langchain_core.v1.messages import MessageV1Types
def default(obj: Any) -> Any: def default(obj: Any) -> Any:
@ -21,17 +24,12 @@ def default(obj: Any) -> Any:
return obj.to_json() return obj.to_json()
# Handle v1 message classes # Handle v1 message classes
from langchain_core.v1.messages import MessageV1Types
if type(obj) in MessageV1Types: if type(obj) in MessageV1Types:
import dataclasses
import inspect
# Get the constructor signature to only include valid parameters # Get the constructor signature to only include valid parameters
init_sig = inspect.signature(type(obj).__init__) init_sig = inspect.signature(type(obj).__init__)
valid_params = set(init_sig.parameters.keys()) - {"self"} valid_params = set(init_sig.parameters.keys()) - {"self"}
# Filter the dataclass fields to only include constructor parameters # Filter dataclass fields to only include constructor params
all_fields = dataclasses.asdict(obj) all_fields = dataclasses.asdict(obj)
kwargs = {k: v for k, v in all_fields.items() if k in valid_params} kwargs = {k: v for k, v in all_fields.items() if k in valid_params}

View File

@ -156,9 +156,9 @@ class Reviver:
cls = getattr(mod, name) cls = getattr(mod, name)
# The class must be a subclass of Serializable or a v1 message class.
from langchain_core.v1.messages import MessageV1Types from langchain_core.v1.messages import MessageV1Types
# The class must be a subclass of Serializable or a v1 message class.
if not (issubclass(cls, Serializable) or cls in MessageV1Types): if not (issubclass(cls, Serializable) or cls in MessageV1Types):
msg = f"Invalid namespace: {value}" msg = f"Invalid namespace: {value}"
raise ValueError(msg) raise ValueError(msg)