shardformer fp8

GuangyaoZhang
2024-07-08 07:04:48 +00:00
parent 51f916b11d
commit 457a0de79f
16 changed files with 520 additions and 234 deletions

View File

@@ -224,7 +224,10 @@ def main():
         # modify the param accordingly for finetuning test cases
         plugin = HybridParallelPlugin(
             tp_size=1,
-            pp_size=2,
+            pp_size=1,
+            sp_size=2,
+            enable_sequence_parallelism=True,
+            sequence_parallelism_mode="all_to_all",
             num_microbatches=None,
             pp_style="interleaved",
             num_model_chunks=2,
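The hunk above drops the two-stage pipeline setup (pp_size=2) in favour of degree-2 sequence parallelism in "all_to_all" mode. A minimal sketch of just the parallelism-related plugin construction, assuming a recent ColossalAI where launch_from_torch needs no config dict and a torchrun --standalone --nproc_per_node 4 launch; the other finetune.py arguments are omitted:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Attach to the distributed environment created by torchrun (4 processes assumed).
colossalai.launch_from_torch()

# Parallelism arguments mirroring the new hunk: no tensor or pipeline
# parallelism, sequence parallelism of degree 2 in all_to_all mode.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    num_microbatches=None,
)
booster = Booster(plugin=plugin)  # model/optimizer/dataloader are boosted elsewhere in finetune.py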

View File

@@ -5,7 +5,7 @@ pip install -r requirements.txt
 FAIL_LIMIT=3
-for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
+for plugin in "hybrid_parallel"; do
     for i in $(seq 1 $FAIL_LIMIT); do
         torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert" && break
         echo "Failed $i times"

View File

@@ -218,8 +218,11 @@ def main():
     elif args.plugin == "hybrid_parallel":
         # modify the param accordingly for finetuning test cases
         plugin = HybridParallelPlugin(
-            tp_size=1,
-            pp_size=2,
+            tp_size=2,
+            pp_size=1,
+            sp_size=2,
+            sequence_parallelism_mode="split_gather",
+            enable_sequence_parallelism=True,
             num_microbatches=None,
             microbatch_size=1,
             enable_all_optimization=True,
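This hunk instead raises tp_size to 2, drops the pipeline stages, and selects the "split_gather" sequence parallelism mode. A minimal sketch mirroring those arguments, assuming colossalai.launch_from_torch() has already been called on 4 processes as in the sketch above:

from colossalai.booster.plugin import HybridParallelPlugin

# Values taken directly from the hunk: tensor parallelism of degree 2 plus
# split_gather sequence parallelism; pipeline parallelism is disabled.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    sequence_parallelism_mode="split_gather",
    enable_sequence_parallelism=True,
    num_microbatches=None,
    microbatch_size=1,
    enable_all_optimization=True,
)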
@@ -318,3 +321,7 @@ def main():
 if __name__ == "__main__":
     main()
+    if dist.get_rank() == 0:
+        import pdb
+        pdb.set_trace()
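The final hunk leaves a rank-0 pdb breakpoint after main() returns. A minimal sketch of that rank-gated debugging pattern, assuming torch.distributed is already initialized inside main() (as the ColossalAI launcher does); the stub main and the closing barrier are assumptions, not part of the diff:

import torch.distributed as dist

def main():
    ...  # stands in for the finetuning entry point in finetune.py, which initializes torch.distributed

if __name__ == "__main__":
    main()
    # Only rank 0 drops into the debugger, so a single process owns the terminal.
    if dist.get_rank() == 0:
        import pdb
        pdb.set_trace()
    # Assumption (not in the diff): hold the other ranks until rank 0 finishes.
    dist.barrier()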