#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from colossalai.legacy.context import seed
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
from colossalai.legacy.registry import LAYERS
from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax


@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.

    Args:
        hidden_size (int): hidden size.
        num_attention_heads (int): number of attention heads.
        attention_dropout (float): dropout probability for attention layer.
        attention_mask_func (:class:`typing.Callable`): Mask function to be applied.
        layer_number (int): number of layers.
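
    A minimal usage sketch (illustrative only; it assumes a ``colossalai.legacy`` run has been
    launched with sequence parallelism enabled, and the input tensors below are placeholders)::

        # hidden_states: [sub_seq_len, batch_size, hidden_size] held by this rank
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]
        attn = TransformerSelfAttentionRing(
            hidden_size=1024,
            num_attention_heads=16,
            attention_dropout=0.1,
            attention_mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
            layer_number=1,
        )
        output, bias = attn(hidden_states, attention_mask)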

    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_dropout,
        attention_mask_func,
        layer_number,
        apply_query_key_layer_scaling: bool = False,
        convert_fp16_to_fp32_in_softmax: bool = False,
        attn_mask_type=AttnMaskType.padding,
        masked_softmax_fusion=True,
        fp16=False,
        bf16=False,
    ):
        super().__init__()
        self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_mask_func = attention_mask_func
        self.layer_number = layer_number
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attn_mask_type = attn_mask_type
        assert self.layer_number > 0

        if self.apply_query_key_layer_scaling:
            self.convert_fp16_to_fp32_in_softmax = True

        assert (
            self.hidden_size % self.num_attention_heads == 0
        ), "hidden size is not divisible by the number of attention heads"

        self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads

        self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # Fused linear layer producing the query, key and value projections in a single matmul.
        self.query_key_value = _Linear(
            hidden_size,
            3 * self.hidden_size,
        )

        self.coeff = None
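        # NOTE: norm_factor scales the raw attention scores by sqrt(hidden_size) (see forward),
        # whereas vanilla scaled dot-product attention divides by sqrt(head_size).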
        self.norm_factor = math.sqrt(self.hidden_size)

        if self.apply_query_key_layer_scaling:
            self.coeff = layer_number
            self.norm_factor *= self.coeff
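            # The scores are additionally divided by the layer index above; FusedScaleMaskSoftmax
            # multiplies them back by `coeff` (after the optional fp32 cast), so the extra division
            # mainly guards against overflow of the raw fp16 scores.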

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            fp16,
            bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            self.attention_mask_func,
            self.convert_fp16_to_fp32_in_softmax,
            self.coeff,
        )

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output.
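        # skip_bias_add=True returns the bias separately so the caller can fuse it with the
        # subsequent dropout / residual add instead of adding it inside the linear layer.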
        self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [sub_seq_len, batch_size, hidden_size]
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]
        sub_seq_length, batch_size, hidden_size = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads shape change:
        # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [sub_seq_len, batch_size, 3 * hidden_size] --> [sub_seq_len, batch_size, num_heads, 3 * head_size]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # split into query, key and value
        last_dim = mixed_x_layer.dim() - 1
        last_dim_value = mixed_x_layer.size(-1)
        assert last_dim_value % 3 == 0, (
            "the last dimension is not a multiple of 3, so it cannot be split into query, key and value"
        )
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim)

        # attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0) * self.world_size,
        )

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1)

        # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
        attention_scores = RingQK.apply(
            query_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size]
            key_layer.transpose(0, 1).contiguous(),  # [batch_size * num_heads, sub_seq_len, head_size],
            batch_size,
            self.num_attention_heads,
            sub_seq_length,
        )
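        # RingQK circulates the key sub-blocks around the sequence-parallel ring so that each
        # rank scores its local queries against the full key sequence, which is why the last
        # dimension of the result is sub_seq_len * world_size.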

        attention_scores /= self.norm_factor

        # change view to [batch_size, num_heads, sub_seq_len, seq_len]
        attention_scores = attention_scores.view(*output_size)

        # apply the mask and softmax; attention_probs keeps the shape [batch_size, num_heads, sub_seq_len, seq_len]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
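        # The seed context switches to the RNG state registered for ParallelMode.TENSOR, so the
        # dropout mask is drawn from that parallel mode's seed.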
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)

        # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [sub_seq_len, batch_size * num_heads, head_size]
        value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view to [batch_size * num_heads, sub_seq_len, seq_len]
        attention_probs = attention_probs.view(
            attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)
        )

        # matmul: [batch_size * num_heads, sub_seq_len, head_size]
        context_layer = RingAV.apply(
            attention_probs,
            value_layer.transpose(0, 1).contiguous(),
            batch_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            sub_seq_length,
        )
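        # RingAV circulates the value sub-blocks around the same ring; each rank multiplies its
        # full-sequence attention probabilities with the incoming value block and accumulates
        # the partial context.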

        # change view [batch_size, num_heads, sub_seq_len, head_size]
        context_layer = context_layer.view(*output_size)

        # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_attention_head * self.num_attention_heads,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        output, bias = self.dense(context_layer)

        return output, bias

    def __repr__(self):
        return (
            f"TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, "
            f"layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, "
            f"attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, "
            f"hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, "
            f"convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})"
        )


class _Linear(nn.Module):
    """Linear layer with column parallelism.
    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
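
    Example (illustrative; ``x`` is a placeholder input whose last dimension equals ``input_size``)::

        proj = _Linear(1024, 1024, bias=True, skip_bias_add=True)
        out, bias = proj(x)   # the bias is returned instead of added
        y = out + bias        # the caller adds (or fuses) the bias later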
    """

    def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(
            torch.empty(
                self.output_size,
                self.input_size,
            )
        )
        nn.init.xavier_normal_(self.weight)

        if bias:
            self.bias = Parameter(torch.empty(self.output_size))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output = F.linear(input_, self.weight, bias)

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output

    def __repr__(self):
        return (
            f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
            + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})"
        )