mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 23:57:21 +00:00
rework import ordering
This commit is contained in:
parent
ad8db57efa
commit
0e6f3ecad0
@ -1,11 +1,14 @@
|
|||||||
"""Dump objects to json."""
|
"""Dump objects to json."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
from langchain_core.load.serializable import Serializable, to_json_not_implemented
|
||||||
|
from langchain_core.v1.messages import MessageV1Types
|
||||||
|
|
||||||
|
|
||||||
def default(obj: Any) -> Any:
|
def default(obj: Any) -> Any:
|
||||||
@ -21,17 +24,12 @@ def default(obj: Any) -> Any:
|
|||||||
return obj.to_json()
|
return obj.to_json()
|
||||||
|
|
||||||
# Handle v1 message classes
|
# Handle v1 message classes
|
||||||
from langchain_core.v1.messages import MessageV1Types
|
|
||||||
|
|
||||||
if type(obj) in MessageV1Types:
|
if type(obj) in MessageV1Types:
|
||||||
import dataclasses
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
# Get the constructor signature to only include valid parameters
|
# Get the constructor signature to only include valid parameters
|
||||||
init_sig = inspect.signature(type(obj).__init__)
|
init_sig = inspect.signature(type(obj).__init__)
|
||||||
valid_params = set(init_sig.parameters.keys()) - {"self"}
|
valid_params = set(init_sig.parameters.keys()) - {"self"}
|
||||||
|
|
||||||
# Filter the dataclass fields to only include constructor parameters
|
# Filter dataclass fields to only include constructor params
|
||||||
all_fields = dataclasses.asdict(obj)
|
all_fields = dataclasses.asdict(obj)
|
||||||
kwargs = {k: v for k, v in all_fields.items() if k in valid_params}
|
kwargs = {k: v for k, v in all_fields.items() if k in valid_params}
|
||||||
|
|
||||||
|
@ -156,9 +156,9 @@ class Reviver:
|
|||||||
|
|
||||||
cls = getattr(mod, name)
|
cls = getattr(mod, name)
|
||||||
|
|
||||||
# The class must be a subclass of Serializable or a v1 message class.
|
|
||||||
from langchain_core.v1.messages import MessageV1Types
|
from langchain_core.v1.messages import MessageV1Types
|
||||||
|
|
||||||
|
# The class must be a subclass of Serializable or a v1 message class.
|
||||||
if not (issubclass(cls, Serializable) or cls in MessageV1Types):
|
if not (issubclass(cls, Serializable) or cls in MessageV1Types):
|
||||||
msg = f"Invalid namespace: {value}"
|
msg = f"Invalid namespace: {value}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
Loading…
Reference in New Issue
Block a user