[feat] Linear1D_COL/ROW support zbv WeightGradStore;

This commit is contained in:
duanjunwen
2024-10-14 07:02:43 +00:00
parent 0ca16d5cbe
commit cfade4c36d
7 changed files with 820 additions and 28 deletions

View File

@@ -5,6 +5,8 @@ import warnings
from contextlib import nullcontext
import torch
torch.autograd.set_detect_anomaly(True)
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
@@ -251,6 +253,7 @@ def main():
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
make_vocab_size_divisible_by=1,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":