mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI into kto
This commit is contained in:
commit
845ea7214e
@ -1 +1,3 @@
|
||||
2.1.0-12.1.0
|
||||
2.2.2-12.1.0
|
||||
2.3.0-12.1.0
|
||||
|
@ -55,41 +55,27 @@ jobs:
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -U pip setuptools==68.2.2 wheel --user
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
apt update && apt install -y cmake
|
||||
pip install -r requirements.txt
|
||||
DISABLE_URING=1 pip install -v .
|
||||
pip install -U pip setuptools==68.2.2 wheel --user
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
- name: Download cub for CUDA 10.2
|
||||
run: |
|
||||
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
|
||||
|
||||
# check if it is CUDA 10.2
|
||||
# download cub
|
||||
if [ "$CUDA_VERSION" = "10.2" ]; then
|
||||
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
|
||||
unzip 1.8.0.zip
|
||||
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
fi
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
BUILD_EXT=1 pip install -v .
|
||||
pip install -r requirements/requirements-test.txt
|
||||
pip install --no-cache-dir -r requirements/requirements-test.txt
|
||||
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
|
||||
|
||||
- name: Unit Testing
|
||||
run: |
|
||||
PYTHONPATH=$PWD pytest --durations=0 tests
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
|
||||
LLAMA_PATH: /data/scratch/llama-tiny
|
||||
MOE_TENSOR_PATH: /data/scratch/moe_tensors
|
||||
|
33
.github/workflows/compatiblity_test_on_pr.yml
vendored
33
.github/workflows/compatiblity_test_on_pr.yml
vendored
@ -49,42 +49,27 @@ jobs:
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -U pip setuptools==68.2.2 wheel --user
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
apt update && apt install -y cmake
|
||||
pip install -r requirements.txt
|
||||
DISABLE_URING=1 pip install -v .
|
||||
pip install -U pip setuptools==68.2.2 wheel --user
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
- name: Download cub for CUDA 10.2
|
||||
run: |
|
||||
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
|
||||
|
||||
# check if it is CUDA 10.2
|
||||
# download cub
|
||||
if [ "$CUDA_VERSION" = "10.2" ]; then
|
||||
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
|
||||
unzip 1.8.0.zip
|
||||
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
fi
|
||||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
BUILD_EXT=1 pip install -v .
|
||||
pip install -r requirements/requirements-test.txt
|
||||
pip install --no-cache-dir -r requirements/requirements-test.txt
|
||||
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
|
||||
|
||||
- name: Unit Testing
|
||||
run: |
|
||||
PYTHONPATH=$PWD pytest --durations=0 tests
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
|
||||
LLAMA_PATH: /data/scratch/llama-tiny
|
||||
MOE_TENSOR_PATH: /data/scratch/moe_tensors
|
||||
|
@ -43,47 +43,28 @@ jobs:
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt update && apt install -y cmake
|
||||
pip install -U pip setuptools==68.2.2 wheel --user
|
||||
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
repository: hpcaitech/TensorNVMe
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
path: TensorNVMe
|
||||
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
cd TensorNVMe
|
||||
apt update && apt install -y cmake
|
||||
pip install -r requirements.txt
|
||||
DISABLE_URING=1 pip install -v .
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
|
||||
|
||||
- name: Download cub for CUDA 10.2
|
||||
run: |
|
||||
CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
|
||||
|
||||
# check if it is CUDA 10.2
|
||||
# download cub
|
||||
if [ "$CUDA_VERSION" = "10.2" ]; then
|
||||
wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
|
||||
unzip 1.8.0.zip
|
||||
cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
|
||||
fi
|
||||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
BUILD_EXT=1 pip install -v .
|
||||
pip install -r requirements/requirements-test.txt
|
||||
pip install --no-cache-dir -r requirements/requirements-test.txt
|
||||
|
||||
- name: Install tensornvme
|
||||
run: |
|
||||
DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
|
||||
|
||||
- name: Unit Testing
|
||||
run: |
|
||||
PYTHONPATH=$PWD pytest --durations=0 tests
|
||||
env:
|
||||
DATA: /data/scratch/cifar-10
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
|
||||
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
|
||||
LLAMA_PATH: /data/scratch/llama-tiny
|
||||
MOE_TENSOR_PATH: /data/scratch/moe_tensors
|
||||
|
||||
|
@ -2,7 +2,7 @@ import ctypes
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
||||
@ -61,6 +64,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||
use_ddp: bool,
|
||||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
overlap_allgather: bool = False,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.shard_config = shard_config
|
||||
@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||
self.sp_group = sp_group
|
||||
self.use_dpp = use_ddp
|
||||
self.require_grad_sync = True
|
||||
self.overlap_allgather = overlap_allgather
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
if custom_policy is not None:
|
||||
@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
super().__init__(module)
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
|
||||
def sync_shared_params(self):
|
||||
for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
|
||||
@ -197,7 +208,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
with self._wait_all_gather():
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def unwrap(self):
|
||||
module = super().unwrap()
|
||||
@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
||||
module = module.module
|
||||
return module
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
def _wait_all_gather(self):
|
||||
return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
|
||||
|
||||
def get_param_info(optim: Optimizer):
|
||||
# Get a backup of necessary information of parameters for future use, which includes:
|
||||
@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
self.model = model
|
||||
self.param_info = param_info
|
||||
@ -677,6 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||
cpu_offload=cpu_offload,
|
||||
dp_process_group=dp_process_group,
|
||||
forced_dtype=forced_dtype,
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
def sync_dp_grads(self):
|
||||
@ -992,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
@ -1143,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2),
|
||||
forced_dtype=PRECISION_TORCH_TYPE[precision],
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
@ -1220,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if zero_stage == 0:
|
||||
@ -1302,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
||||
# so we disable it, performing manual reduction instead.
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
with ctx:
|
||||
with ctx, model._wait_all_gather():
|
||||
outputs = self.schedule.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
@ -2,6 +2,7 @@ import enum
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .dp_plugin_base import DPPluginBase
|
||||
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
||||
@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
|
||||
|
||||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(self, module: nn.Module, precision: str) -> None:
|
||||
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
if precision == "fp16":
|
||||
@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
self.convert_fn = None
|
||||
if self.dtype is not None:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
if overlap_allgather:
|
||||
self.op_hook = ZeroOpHook()
|
||||
for p in module.parameters():
|
||||
if p.requires_grad and type(p) is not ColoParameter:
|
||||
p.__class__ = ColoParameter
|
||||
p.__init__(p, requires_grad=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
||||
with ctx:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for p in self.module.parameters():
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_unsharded_model(model, checkpoint, strict)
|
||||
model.update_master_params()
|
||||
|
||||
@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
load_sub_module: bool = True,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
||||
model.update_master_params()
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint_path: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
max_shard_size: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
return super().save_sharded_model(
|
||||
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
reduce_bucket_size_in_m: int = 12,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
cpu_offload: bool = False,
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
@ -315,6 +356,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
partition_grad=(stage == 2),
|
||||
cpu_offload=cpu_offload,
|
||||
master_weights=master_weights,
|
||||
overlap_allgather=overlap_allgather,
|
||||
)
|
||||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
@ -431,7 +473,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
self.add_lora_params_to_optimizer(model, optimizer)
|
||||
|
||||
if not isinstance(model, ModelWrapper):
|
||||
model = LowLevelZeroModel(model, self.precision)
|
||||
model = LowLevelZeroModel(
|
||||
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
|
||||
)
|
||||
|
||||
# TODO: Support Galore + ZeRO
|
||||
zero_stage = self.stage
|
||||
|
@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
"""
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
model = model.unwrap()
|
||||
|
||||
if os.path.isfile(checkpoint):
|
||||
@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
"""
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
model_before_wrapping = model # backup for model before wrapping
|
||||
model = model.unwrap()
|
||||
|
||||
@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
model = model.unwrap()
|
||||
|
||||
if self.dp_rank != 0:
|
||||
@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
|
||||
model._force_wait_all_gather()
|
||||
strict = False
|
||||
model_before_wrapping = model
|
||||
model = model.unwrap()
|
||||
|
@ -91,7 +91,11 @@ def _broadcast_object_list(
|
||||
my_rank = dist.get_rank()
|
||||
# Serialize object_list elements to tensors on src rank.
|
||||
if my_rank == src:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
if Version(torch.__version__) >= Version("2.3.0"):
|
||||
tensor_list, size_list = zip(
|
||||
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
|
||||
)
|
||||
elif Version(torch.__version__) >= Version("1.13.0"):
|
||||
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
|
||||
else:
|
||||
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
|
||||
@ -276,7 +280,11 @@ def _send_recv_serialization_object(
|
||||
send_object_tensor = None
|
||||
send_object_size_tensor = None
|
||||
if object is not None and send_dst is not None:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
if Version(torch.__version__) >= Version("2.3.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
|
||||
object, device=current_device, group=send_group
|
||||
)
|
||||
elif Version(torch.__version__) >= Version("1.13.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
|
||||
else:
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
next_decoder_cache = None
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
|
@ -473,7 +473,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
for process_group in used_process_groups:
|
||||
try:
|
||||
dist.get_rank(process_group)
|
||||
except RuntimeError as e:
|
||||
except (ValueError, RuntimeError) as e:
|
||||
# If the group is not registered, it means it has been deleted
|
||||
if str(e) == (
|
||||
f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
|
||||
|
@ -1,4 +1,3 @@
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List
|
||||
|
||||
from ..utils import merge_same_dim_mesh_list
|
||||
@ -23,10 +22,11 @@ class DimSpec:
|
||||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||
"""
|
||||
|
||||
_DIFFERENCE_DICT = None
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
self.build_difference_2d_dict()
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
@ -39,24 +39,43 @@ class DimSpec:
|
||||
target += str(dim)
|
||||
return target
|
||||
|
||||
def _convert_str_to_shard_list(self, str_spec):
|
||||
@property
|
||||
def difference_dict(self):
|
||||
"""
|
||||
Convert str_spec into shard_list.
|
||||
Returns the difference dict, and lazily initializes it when needed
|
||||
|
||||
Return:
|
||||
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
|
||||
difference dict
|
||||
"""
|
||||
if self._DIFFERENCE_DICT is None:
|
||||
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||
|
||||
return self._DIFFERENCE_DICT
|
||||
|
||||
def dim_diff(self, other):
|
||||
"""
|
||||
The difference between two DimSpec.
|
||||
|
||||
Argument:
|
||||
str_spec(str): dim spec in str type.
|
||||
other(DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = DimSpec([0])
|
||||
other_dim_spec = DimSpec([0, 1])
|
||||
print(dim_spec.dim_diff(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
def build_difference_2d_dict(self):
|
||||
@classmethod
|
||||
def _build_difference_2d_dict(cls):
|
||||
"""
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
@ -67,9 +86,8 @@ class DimSpec:
|
||||
difference_dict = {}
|
||||
for source_spec in source_spec_list:
|
||||
for target_spec in target_spec_list:
|
||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
||||
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||
|
||||
# source same as target
|
||||
if source_shard_list == target_shard_list:
|
||||
@ -112,30 +130,27 @@ class DimSpec:
|
||||
|
||||
else:
|
||||
difference = NAN
|
||||
difference_dict[spec_pair] = difference
|
||||
difference_dict[(source_spec, target_spec)] = difference
|
||||
|
||||
self.difference_dict = difference_dict
|
||||
return difference_dict
|
||||
|
||||
def dim_diff(self, other):
|
||||
@staticmethod
|
||||
def _convert_str_to_shard_list(str_spec):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
Convert str_spec into shard_list.
|
||||
|
||||
Argument:
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
str_spec(str): dim spec in str type.
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
|
||||
class ShardingSpec:
|
||||
|
@ -1,5 +1,4 @@
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
@ -27,10 +26,11 @@ class _DimSpec:
|
||||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||
"""
|
||||
|
||||
_DIFFERENCE_DICT = None
|
||||
|
||||
def __init__(self, shard_list):
|
||||
self.is_replica = len(shard_list) == 0
|
||||
self.shard_list = shard_list
|
||||
self.build_difference_2d_dict()
|
||||
|
||||
def __eq__(self, other):
|
||||
return str(self) == str(other)
|
||||
@ -43,27 +43,46 @@ class _DimSpec:
|
||||
target += str(dim)
|
||||
return target
|
||||
|
||||
def _convert_str_to_shard_list(self, str_spec):
|
||||
@property
|
||||
def difference_dict(self):
|
||||
"""
|
||||
Convert str_spec into shard_list.
|
||||
Returns the difference dict, and lazily initializes it when needed
|
||||
|
||||
Return:
|
||||
difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
|
||||
difference dict
|
||||
"""
|
||||
if self._DIFFERENCE_DICT is None:
|
||||
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||
|
||||
return self._DIFFERENCE_DICT
|
||||
|
||||
def difference(self, other):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
|
||||
Argument:
|
||||
str_spec(str): dim spec in str type.
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
def build_difference_2d_dict(self):
|
||||
@classmethod
|
||||
def _build_difference_2d_dict(cls):
|
||||
"""
|
||||
Build a difference mapping for 2D device mesh case. It will be used to
|
||||
compute the difference between DimSpec pairs.
|
||||
compute the difference between _DimSpec pairs.
|
||||
"""
|
||||
|
||||
source_spec_list = ["R", "S0", "S1", "S01"]
|
||||
@ -71,9 +90,8 @@ class _DimSpec:
|
||||
difference_dict = {}
|
||||
for source_spec in source_spec_list:
|
||||
for target_spec in target_spec_list:
|
||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
||||
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||
|
||||
# source same as target
|
||||
if source_shard_list == target_shard_list:
|
||||
@ -116,30 +134,27 @@ class _DimSpec:
|
||||
|
||||
else:
|
||||
difference = NAN
|
||||
difference_dict[spec_pair] = difference
|
||||
difference_dict[(source_spec, target_spec)] = difference
|
||||
|
||||
self.difference_dict = difference_dict
|
||||
return difference_dict
|
||||
|
||||
def difference(self, other):
|
||||
@staticmethod
|
||||
def _convert_str_to_shard_list(str_spec):
|
||||
"""
|
||||
The difference between two _DimSpec.
|
||||
Convert str_spec into shard_list.
|
||||
|
||||
Argument:
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
|
||||
Output:
|
||||
5
|
||||
str_spec(str): dim spec in str type.
|
||||
"""
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
|
||||
if str_spec == "R":
|
||||
return []
|
||||
if str_spec == "S0":
|
||||
return [0]
|
||||
if str_spec == "S1":
|
||||
return [1]
|
||||
if str_spec == "S01":
|
||||
return [0, 1]
|
||||
|
||||
|
||||
class ShardingSpecException(Exception):
|
||||
|
@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger
|
||||
|
||||
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
|
||||
from .bookkeeping import BucketStore, GradientStore, TensorBucket
|
||||
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
|
||||
|
||||
|
||||
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
dp_process_group: Optional[ProcessGroup] = None,
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
master_weights: bool = True, # master weights
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
|
||||
|
||||
@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
# communication params
|
||||
self._overlap_communication = overlap_communication
|
||||
self._overlap_allgather = overlap_allgather
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
self._communication_dtype = communication_dtype
|
||||
|
||||
@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
# record the padding size of each param
|
||||
self._padding_map = dict()
|
||||
# padded working param is all-gather buffer and it shares the same memory with working param
|
||||
self._working_param_to_padded_working_param = dict()
|
||||
|
||||
# mapping working param and master param
|
||||
self.master_to_working_param = dict()
|
||||
@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
with torch.no_grad():
|
||||
if padding_size > 0:
|
||||
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
||||
# reset working params' ptr when no master weights
|
||||
if self._master_weights == False:
|
||||
param.data = padding_param[: param.numel()].view(param.shape)
|
||||
# # reset working params' ptr when no master weights
|
||||
# if self._master_weights == False:
|
||||
param.data = padding_param[: param.numel()].view(param.shape)
|
||||
else:
|
||||
padding_param = param.data.view(-1)
|
||||
self._working_param_to_padded_working_param[param] = padding_param
|
||||
|
||||
splited_params = padding_param.split(
|
||||
padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
|
||||
@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
|
||||
# use fp32 when master_weights is True
|
||||
if self._master_weights is True:
|
||||
splited_param_current_rank = splited_params.detach().float().to(device)
|
||||
splited_param_current_rank = splited_params.detach().clone().float().to(device)
|
||||
else:
|
||||
splited_param_current_rank = splited_params
|
||||
|
||||
@ -549,22 +555,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
working_param = real_working_params[group_id][idx]
|
||||
param_to_gather = master_param.to(device).to(self._dtype)
|
||||
pg = self.param_to_pg[working_param]
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
buffer_tensor = torch.empty_like(
|
||||
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
|
||||
)
|
||||
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
|
||||
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
except RuntimeError:
|
||||
self.pg_to_tensor_bucket[pg].all_gather(pg)
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
padded_working_param = self._working_param_to_padded_working_param[working_param]
|
||||
if self._overlap_allgather:
|
||||
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
|
||||
set_all_gather_handle(working_param, handle)
|
||||
else:
|
||||
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
|
||||
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
|
||||
continue
|
||||
try:
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
except RuntimeError:
|
||||
self.pg_to_tensor_bucket[pg].all_gather(pg)
|
||||
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
|
||||
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
||||
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
|
||||
if not tensor_bucket.is_empty():
|
||||
tensor_bucket.all_gather(pg)
|
||||
if not self._overlap_allgather:
|
||||
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
|
||||
if not tensor_bucket.is_empty():
|
||||
tensor_bucket.all_gather(pg)
|
||||
|
||||
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
r"""
|
||||
@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
|
||||
grad_store = self.pid_to_grad_store[param_id]
|
||||
return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
for param in self._working_param_to_padded_working_param.keys():
|
||||
wait_all_gather_handle(param)
|
||||
|
33
colossalai/zero/low_level/zero_hook.py
Normal file
33
colossalai/zero/low_level/zero_hook.py
Normal file
@ -0,0 +1,33 @@
|
||||
from typing import List
|
||||
|
||||
from torch._tensor import Tensor
|
||||
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
|
||||
_ALL_GATHER_HANDLE = "_all_gather_handle"
|
||||
|
||||
|
||||
def wait_all_gather_handle(p):
|
||||
if hasattr(p, _ALL_GATHER_HANDLE):
|
||||
handle = getattr(p, _ALL_GATHER_HANDLE)
|
||||
handle.wait()
|
||||
delattr(p, _ALL_GATHER_HANDLE)
|
||||
|
||||
|
||||
def set_all_gather_handle(p, handle):
|
||||
setattr(p, _ALL_GATHER_HANDLE, handle)
|
||||
|
||||
|
||||
class ZeroOpHook(ColoParamOpHook):
|
||||
def pre_forward(self, params: List[Tensor]) -> None:
|
||||
for p in params:
|
||||
wait_all_gather_handle(p)
|
||||
|
||||
def post_forward(self, params: List[Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def pre_backward(self, params: List[Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def post_backward(self, params: List[Tensor]) -> None:
|
||||
pass
|
@ -98,6 +98,7 @@ def main():
|
||||
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
|
||||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch()
|
||||
@ -199,9 +200,9 @@ def main():
|
||||
enable_flash_attention=args.xformers,
|
||||
microbatch_size=args.mbs,
|
||||
precision="bf16",
|
||||
dp_outside=False,
|
||||
overlap_p2p=args.overlap,
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
|
@ -113,13 +113,13 @@ class PerformanceEvaluator:
|
||||
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
|
||||
if self.disable:
|
||||
return
|
||||
get_accelerator().synchronize()
|
||||
# get_accelerator().synchronize()
|
||||
self.timer.start()
|
||||
|
||||
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
get_accelerator().synchronize()
|
||||
# get_accelerator().synchronize()
|
||||
self.timer.end()
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
@ -8,7 +8,7 @@ click
|
||||
fabric
|
||||
contexttimer
|
||||
ninja
|
||||
torch>=2.1.0,<2.3.0
|
||||
torch>=2.1.0,<=2.3.0
|
||||
safetensors
|
||||
einops
|
||||
pydantic
|
||||
|
@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{ # Ulysess + Flash attention
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
@ -242,6 +197,54 @@ def run_qwen2_test(test_config):
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"enable_all_optimization": True,
|
||||
"use_lazy_init": True,
|
||||
"zero_stage": 1,
|
||||
"precision": "fp16",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_qwen2_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Failed config: {test_config}")
|
||||
raise e
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
try:
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
except Exception as e:
|
||||
print(f"Failed config: {test_config}")
|
||||
raise e
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
|
@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
|
||||
zero1_optimizer.step()
|
||||
zero2_optimizer.step()
|
||||
|
||||
zero1_optimizer._force_wait_all_gather()
|
||||
zero2_optimizer._force_wait_all_gather()
|
||||
|
||||
# check updated param
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
assert not hasattr(z1p, "_all_gather_handle")
|
||||
assert torch.equal(z1p.data, z2p.data)
|
||||
|
||||
|
||||
|
@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
|
||||
# torch ddp step
|
||||
torch_optimizer.step()
|
||||
|
||||
zero_optimizer._force_wait_all_gather()
|
||||
|
||||
# check updated param
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
loose_close(p, z1p, dtype=dtype)
|
||||
|
@ -1 +1 @@
|
||||
0.4.0
|
||||
0.4.1
|
||||
|
Loading…
Reference in New Issue
Block a user