From 81272e9d005ab2aeebe5ddd24d671d3dd93e2b61 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 17 Aug 2024 09:37:37 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../coati/dataset/tokenization_utils.py       |  2 +-
 .../ColossalChat/coati/trainer/dpo.py         |  2 +-
 .../ColossalChat/coati/trainer/kto.py         |  2 +-
 .../ColossalChat/coati/trainer/orpo.py        |  2 +-
 applications/ColossalChat/examples/README.md  |  2 +-
 .../examples/training_scripts/train_kto.py    |  2 +-
 .../examples/training_scripts/train_orpo.py   |  2 +-
 applications/ColossalChat/requirements.txt    |  2 +-
 applications/ColossalChat/tests/test_train.sh |  2 +-
 .../booster/plugin/low_level_zero_plugin.py   |  7 ++++++-
 colossalai/shardformer/layer/loss.py          |  1 +
 colossalai/shardformer/modeling/chatglm2.py   |  5 -----
 colossalai/shardformer/modeling/command.py    |  2 +-
 colossalai/shardformer/modeling/deepseek.py   |  6 ++++--
 colossalai/shardformer/modeling/llama.py      | 20 +++++++++++++------
 colossalai/shardformer/modeling/mixtral.py    |  4 +++-
 colossalai/shardformer/policies/mixtral.py    |  2 +-
 17 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py
index d2bc38aa0..020432b9e 100755
--- a/applications/ColossalChat/coati/dataset/tokenization_utils.py
+++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -392,4 +392,4 @@ def tokenize_kto(
         "label": data_point["label"],
         "input_id_decode": decoded_full_prompt,
         "completion_decode": decoded_completion,
-    }
\ No newline at end of file
+    }
diff --git a/applications/ColossalChat/coati/trainer/dpo.py b/applications/ColossalChat/coati/trainer/dpo.py
index c6a174317..24ddca654 100755
--- a/applications/ColossalChat/coati/trainer/dpo.py
+++ b/applications/ColossalChat/coati/trainer/dpo.py
@@ -356,4 +356,4 @@ class DPOTrainer(SLTrainer):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
diff --git a/applications/ColossalChat/coati/trainer/kto.py b/applications/ColossalChat/coati/trainer/kto.py
index 11fda0aa8..6462ba816 100755
--- a/applications/ColossalChat/coati/trainer/kto.py
+++ b/applications/ColossalChat/coati/trainer/kto.py
@@ -346,4 +346,4 @@ class KTOTrainer(SLTrainer):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
diff --git a/applications/ColossalChat/coati/trainer/orpo.py b/applications/ColossalChat/coati/trainer/orpo.py
index 948232bff..c2f75771c 100644
--- a/applications/ColossalChat/coati/trainer/orpo.py
+++ b/applications/ColossalChat/coati/trainer/orpo.py
@@ -323,4 +323,4 @@ class ORPOTrainer(SLTrainer):
         os.makedirs(self.save_dir, exist_ok=True)
         with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
             f.write(msg)
-        step_bar.close()
\ No newline at end of file
+        step_bar.close()
diff --git a/applications/ColossalChat/examples/README.md b/applications/ColossalChat/examples/README.md
index 7fd6011de..fec7bc061 100755
--- a/applications/ColossalChat/examples/README.md
+++ b/applications/ColossalChat/examples/README.md
@@ -903,4 +903,4 @@ For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/mai
 ## Attention
 
 
-The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
\ No newline at end of file
+The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
diff --git a/applications/ColossalChat/examples/training_scripts/train_kto.py b/applications/ColossalChat/examples/training_scripts/train_kto.py
index db48e33e3..598fd8062 100755
--- a/applications/ColossalChat/examples/training_scripts/train_kto.py
+++ b/applications/ColossalChat/examples/training_scripts/train_kto.py
@@ -375,4 +375,4 @@ if __name__ == "__main__":
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
\ No newline at end of file
+    train(args)
diff --git a/applications/ColossalChat/examples/training_scripts/train_orpo.py b/applications/ColossalChat/examples/training_scripts/train_orpo.py
index fd99af828..87860f7ea 100755
--- a/applications/ColossalChat/examples/training_scripts/train_orpo.py
+++ b/applications/ColossalChat/examples/training_scripts/train_orpo.py
@@ -340,4 +340,4 @@ if __name__ == "__main__":
         os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
         with open(args.config_file, "w") as f:
             json.dump(args.__dict__, f, indent=4)
-    train(args)
\ No newline at end of file
+    train(args)
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index 1978120d1..ac40ae821 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -20,4 +20,4 @@ datasets
 ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
-tiktoken
\ No newline at end of file
+tiktoken
diff --git a/applications/ColossalChat/tests/test_train.sh b/applications/ColossalChat/tests/test_train.sh
index 3c47836d2..69036de63 100755
--- a/applications/ColossalChat/tests/test_train.sh
+++ b/applications/ColossalChat/tests/test_train.sh
@@ -640,4 +640,4 @@ for lora_rank in ${LORA_RANK[@]}; do
             fi
         done
     done
-done
\ No newline at end of file
+done
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 533a2004d..448fb9e21 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -64,7 +64,12 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 41503a6bf..12df824d1 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+
 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig
 
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 8d8608866..a761968af 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -16,11 +16,6 @@ from colossalai.shardformer.layer._operation import (
     gather_forward_split_backward,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer._operation import (
-    all_to_all_comm,
-    gather_forward_split_backward,
-    split_forward_gather_backward,
-)
 
 
 def get_flash_core_attention_forward():
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index b3adcc5d4..530338394 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -24,7 +24,7 @@ from colossalai.shardformer.layer._operation import (
 )
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, dist_cross_entropy
 
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index bec85bd58..59f9d4516 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -145,7 +145,9 @@ class EPDeepseekMoE(nn.Module):
         output_split_sizes = torch.zeros_like(input_split_sizes)
 
         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
-        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication)
+        dist.all_to_all_single(
+            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+        )
 
         with torch.no_grad():
             activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
@@ -694,7 +696,7 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-        
+
         # TODO: upgrade transformers to 4.44.0 to fix the bug, remove the hard code.
         self._use_flash_attention_2 = shard_config.enable_flash_attention
         self._use_sdpa = False if shard_config.enable_flash_attention else self._use_sdpa
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 4fdca2d3c..71d8daa35 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -26,11 +26,15 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import AttnMaskType
-from colossalai.shardformer.layer._operation import all_to_all_comm, gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from colossalai.shardformer.layer.utils import is_share_sp_tp, split_batch_zigzag
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import ColoAttention, RingAttention, dist_cross_entropy, cross_entropy_1d
+from ..layer import ColoAttention, RingAttention, dist_cross_entropy
 
 _SUPPORTED_SP_MODE = ["all_to_all", "split_gather", "ring", "ring_attn"]
 
@@ -162,9 +166,13 @@ class LlamaPipelineForwards:
                     hidden_states, position_ids = split_batch_zigzag([hidden_states, position_ids], sp_group)
 
             elif is_share_sp_tp(sp_mode):
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )
 
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -355,7 +363,7 @@ class LlamaPipelineForwards:
             loss = dist_cross_entropy(
                 labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
             )
-            
+
             if not return_dict:
                 output = (logits,) + outputs[1:]
                 return (loss,) + output if loss is not None else output
@@ -675,7 +683,7 @@ def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=N
 
         past_seen_tokens = 0
         seq_len = inputs_embeds.shape[1]
-        batch_size = inputs_embeds.shape[0]
+        inputs_embeds.shape[0]
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index c488d3cc4..50334677e 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -691,7 +691,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 1f19ff65d..4bdca78cb 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -5,7 +5,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM, MixtralModel
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D