from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten


# These OrderedDict pytree registrations are for torch versions below 1.13.1 and may be removed in the future.
def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]:
    """Split an OrderedDict into (children, context) for pytree flattening.

    The children are the values and the context is the list of keys, both in
    insertion order.
    """
    children = list(d.values())
    keys = list(d.keys())
    return children, keys


def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]":
    """Rebuild an OrderedDict by pairing each key in ``context`` with the value at the same position."""
    return OrderedDict(zip(context, values))


# Register only when this torch version does not already know OrderedDict as a
# pytree node, so that an existing (built-in) node definition is not replaced.
if OrderedDict not in SUPPORTED_NODES:
    _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)


def tree_map_hf(fn: Any, pytree: Any):
    """Apply ``fn`` to every leaf of ``pytree`` and rebuild the tree.

    Flattening goes through ``tree_flatten_hf`` rather than torch's
    ``tree_flatten``. OrderedDict instances (including subclasses) at the top
    level, or nested inside other OrderedDicts, are decomposed instead of
    being treated as leaves. They come back as plain ``OrderedDict``.
    """
    leaves, treespec = tree_flatten_hf(pytree)
    mapped_leaves = list(map(fn, leaves))
    return tree_unflatten(mapped_leaves, treespec)


# Use this flatten function to handle model output class instances (e.g. HuggingFace ``ModelOutput``).
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into a list of leaves and a TreeSpec that can rebuild it.

    Any ``OrderedDict`` instance, including subclasses, is split with the
    registered OrderedDict flatten function. Its values are flattened
    recursively with this same function. The resulting spec records the node
    type as plain ``OrderedDict``, so unflattening yields an ``OrderedDict``
    rather than the original subclass. Every other object is delegated to
    torch's ``tree_flatten``. Only OrderedDicts at the top level, or nested
    directly inside other OrderedDicts, take the subclass-aware path.
    """
    if not isinstance(pytree, OrderedDict):
        return tree_flatten(pytree)

    children, context = SUPPORTED_NODES[OrderedDict].flatten_fn(pytree)

    leaves: List[Any] = []
    child_specs: List["TreeSpec"] = []
    for child in children:
        child_leaves, spec = tree_flatten_hf(child)
        leaves.extend(child_leaves)
        child_specs.append(spec)
    return leaves, TreeSpec(OrderedDict, context, child_specs)


def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Return ``x`` moved to ``device`` when it is a tensor, otherwise ``x`` itself.

    Args:
        x (Any): Candidate object; only ``torch.Tensor`` instances are moved.
        device (Optional[torch.device], optional): Destination passed to ``Tensor.to``. Defaults to None.

    Returns:
        Any: ``x.to(device)`` for tensors, the untouched object otherwise.
    """
    return x.to(device) if isinstance(x, torch.Tensor) else x


def get_batch_size(batch: Any) -> int:
    """Return ``size(0)`` of the first tensor leaf found in ``batch``.

    Leaves are visited in the order produced by torch's ``tree_flatten``.

    Args:
        batch (Any): Pytree (tensor, list, tuple, dict, ...) to inspect.

    Raises:
        RuntimeError: If ``batch`` contains no tensor leaf.

    Returns:
        int: Length of dimension 0 of the first tensor.
    """
    leaves, _ = tree_flatten(batch)
    first_tensor = next((leaf for leaf in leaves if isinstance(leaf, torch.Tensor)), None)
    if first_tensor is None:
        raise RuntimeError("No tensor found in the batch")
    return first_tensor.size(0)


def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Cut rows ``[start, start + micro_batch_size)`` out of every tensor in ``batch``.

    The pytree structure is preserved; non-tensor leaves are passed through as-is.

    Args:
        batch (Any): Pytree whose tensor leaves are sliced along dimension 0.
        start (int): First row of the micro batch.
        micro_batch_size (int): Number of rows to take.

    Returns:
        Any: A pytree of the same structure holding the sliced tensors.
    """

    def _slice_leaf(leaf: Any):
        if not isinstance(leaf, torch.Tensor):
            return leaf
        return leaf[start : start + micro_batch_size]

    return tree_map(_slice_leaf, batch)


def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Invoke ``model`` on a loaded batch plus any keyword inputs from the previous stage.

    Lists and tuples are unpacked positionally and dicts are unpacked as
    keyword arguments; any other object is passed as the single positional
    argument. ``internal_inputs`` is always forwarded as keyword arguments.

    Args:
        model (Module): Callable model to run.
        data (Any): Data loaded from the data iterator.
        internal_inputs (Optional[dict]): Outputs handed over by the previous stage; None for the first stage.

    Returns:
        Any: Whatever the model returns.
    """
    extra_kwargs = internal_inputs if internal_inputs is not None else {}
    if isinstance(data, (list, tuple)):
        return model(*data, **extra_kwargs)
    if isinstance(data, dict):
        return model(**data, **extra_kwargs)
    return model(data, **extra_kwargs)


def retain_grad(x: Any) -> None:
    """Ask autograd to keep ``.grad`` on ``x`` when it is a tensor that requires grad.

    Non-tensors and tensors with ``requires_grad=False`` are silently ignored.

    Args:
        x (Any): Candidate object.
    """
    if not isinstance(x, torch.Tensor) or not x.requires_grad:
        return
    x.retain_grad()


def detach(x: Any) -> Any:
    """Return ``x.detach()`` for tensors and ``x`` unchanged for anything else.

    Args:
        x (Any): Candidate object.

    Returns:
        Any: A tensor cut off from the autograd graph, or the original object.
    """
    return x.detach() if isinstance(x, torch.Tensor) else x


def merge_batch(data: List[Any], batch_size_dim: int = 0) -> Any:
    """Merge micro batches into a batch.

    Each micro batch is flattened with ``tree_flatten_hf``. Leaves at the same
    position across micro batches are combined as follows:

    - tensors with at least one dimension are concatenated along ``batch_size_dim``;
    - 0-dim tensors (e.g. a per-micro-batch loss) become ``None``;
    - any other leaf becomes a list holding its value from every micro batch.

    The merged leaves are rebuilt with the tree spec of the last micro batch.

    Args:
        data (List[Any]): A list of micro batches that share the same structure.
        batch_size_dim (int, optional): Dimension along which tensors are concatenated. Defaults to 0.

    Raises:
        ValueError: If the micro batches do not all flatten to the same number of leaves.

    Returns:
        Any: Merged batch, or ``None`` if ``data`` is empty.
    """
    if len(data) == 0:
        return None

    flattened_data = []
    tree_spec = None
    for d in data:
        # elems is the flat list of leaves of one micro batch
        elems, tree_spec = tree_flatten_hf(d)
        flattened_data.append(elems)

    # zip() would silently truncate to the shortest micro batch and misalign leaves.
    leaf_counts = {len(elems) for elems in flattened_data}
    if len(leaf_counts) != 1:
        raise ValueError(
            f"All micro batches must have the same number of leaves to be merged, got {sorted(leaf_counts)}"
        )

    merged_data = []
    for elem_batch in zip(*flattened_data):
        first = elem_batch[0]
        if isinstance(first, torch.Tensor):
            if first.dim() == 0:
                # 0-dim tensors (e.g. loss) cannot be concatenated; set them to None in pipeline outputs
                merged_data.append(None)
            else:
                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
        else:
            merged_data.append(list(elem_batch))
    return tree_unflatten(merged_data, tree_spec)