YeAnbang 2024-07-18 07:55:43 +00:00
commit 845ea7214e
21 changed files with 351 additions and 228 deletions

View File

@@ -1 +1,3 @@
 2.1.0-12.1.0
+2.2.2-12.1.0
+2.3.0-12.1.0

View File

@@ -55,41 +55,27 @@ jobs:
     steps:
       - name: Install dependencies
         run: |
-          pip install -U pip setuptools==68.2.2 wheel --user
-      - uses: actions/checkout@v2
-        with:
-          repository: hpcaitech/TensorNVMe
-          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-          path: TensorNVMe
-      - name: Install tensornvme
-        run: |
-          cd TensorNVMe
           apt update && apt install -y cmake
-          pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
+          pip install -U pip setuptools==68.2.2 wheel --user
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-      - name: Download cub for CUDA 10.2
-        run: |
-          CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
-          # check if it is CUDA 10.2
-          # download cub
-          if [ "$CUDA_VERSION" = "10.2" ]; then
-            wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
-            unzip 1.8.0.zip
-            cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
-          fi
       - name: Install Colossal-AI
         run: |
           BUILD_EXT=1 pip install -v .
-          pip install -r requirements/requirements-test.txt
+          pip install --no-cache-dir -r requirements/requirements-test.txt
+      - name: Install tensornvme
+        run: |
+          DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
       - name: Unit Testing
         run: |
           PYTHONPATH=$PWD pytest --durations=0 tests
         env:
           DATA: /data/scratch/cifar-10
-          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@@ -49,42 +49,27 @@ jobs:
     steps:
       - name: Install dependencies
         run: |
-          pip install -U pip setuptools==68.2.2 wheel --user
-      - uses: actions/checkout@v2
-        with:
-          repository: hpcaitech/TensorNVMe
-          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-          path: TensorNVMe
-      - name: Install tensornvme
-        run: |
-          cd TensorNVMe
           apt update && apt install -y cmake
-          pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
+          pip install -U pip setuptools==68.2.2 wheel --user
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-      - name: Download cub for CUDA 10.2
-        run: |
-          CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
-          # check if it is CUDA 10.2
-          # download cub
-          if [ "$CUDA_VERSION" = "10.2" ]; then
-            wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
-            unzip 1.8.0.zip
-            cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
-          fi
       - name: Install Colossal-AI
         run: |
           BUILD_EXT=1 pip install -v .
-          pip install -r requirements/requirements-test.txt
+          pip install --no-cache-dir -r requirements/requirements-test.txt
+      - name: Install tensornvme
+        run: |
+          DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
       - name: Unit Testing
         run: |
           PYTHONPATH=$PWD pytest --durations=0 tests
         env:
           DATA: /data/scratch/cifar-10
-          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@@ -43,47 +43,28 @@ jobs:
     steps:
      - name: Install dependencies
        run: |
+          apt update && apt install -y cmake
           pip install -U pip setuptools==68.2.2 wheel --user
       - uses: actions/checkout@v2
         with:
-          repository: hpcaitech/TensorNVMe
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-          path: TensorNVMe
-      - name: Install tensornvme
-        run: |
-          cd TensorNVMe
-          apt update && apt install -y cmake
-          pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
-      - uses: actions/checkout@v2
-        with:
-          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
-      - name: Download cub for CUDA 10.2
-        run: |
-          CUDA_VERSION=$(nvcc -V | awk -F ',| ' '/release/{print $6}')
-          # check if it is CUDA 10.2
-          # download cub
-          if [ "$CUDA_VERSION" = "10.2" ]; then
-            wget https://github.com/NVIDIA/cub/archive/refs/tags/1.8.0.zip
-            unzip 1.8.0.zip
-            cp -r cub-1.8.0/cub/ colossalai/kernel/cuda_native/csrc/kernels/include/
-          fi
       - name: Install Colossal-AI
         run: |
           BUILD_EXT=1 pip install -v .
-          pip install -r requirements/requirements-test.txt
+          pip install --no-cache-dir -r requirements/requirements-test.txt
+      - name: Install tensornvme
+        run: |
+          DISABLE_URING=1 pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
       - name: Unit Testing
         run: |
           PYTHONPATH=$PWD pytest --durations=0 tests
         env:
           DATA: /data/scratch/cifar-10
-          LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
+          LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors

View File

@@ -2,7 +2,7 @@ import ctypes
 import random
 import warnings
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from copy import deepcopy
 from functools import partial
 from types import MethodType
@@ -33,8 +33,11 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

 from .pp_plugin_base import PipelinePluginBase
@@ -61,6 +64,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
+        overlap_allgather: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -69,6 +73,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True
+        self.overlap_allgather = overlap_allgather

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -106,6 +111,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)

         super().__init__(module)
+        if overlap_allgather:
+            self.op_hook = ZeroOpHook()
+            for p in module.parameters():
+                if p.requires_grad and type(p) is not ColoParameter:
+                    p.__class__ = ColoParameter
+                    p.__init__(p, requires_grad=True)

     def sync_shared_params(self):
         for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
@@ -197,7 +208,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+        with self._wait_all_gather():
+            return super().forward(*args, **kwargs)

     def unwrap(self):
         module = super().unwrap()
@@ -205,6 +217,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = module.module
         return module

+    def _force_wait_all_gather(self):
+        for p in self.module.parameters():
+            wait_all_gather_handle(p)
+
+    def _wait_all_gather(self):
+        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+

 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
@@ -650,6 +669,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
+        overlap_allgather: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -677,6 +697,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
+            overlap_allgather=overlap_allgather,
         )

     def sync_dp_grads(self):
@@ -992,6 +1013,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         make_vocab_size_divisible_by: int = 64,
         dp_outside: bool = True,
         overlap_p2p: bool = True,
+        overlap_allgather: bool = False,
     ) -> None:
         super().__init__()
         assert (
@@ -1143,6 +1165,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 cpu_offload=cpu_offload,
                 partition_grad=(self.zero_stage == 2),
                 forced_dtype=PRECISION_TORCH_TYPE[precision],
+                overlap_allgather=overlap_allgather,
             )

         self.max_norm = max_norm
@@ -1220,6 +1243,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
+                overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:
@@ -1302,7 +1326,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx:
+        with ctx, model._wait_all_gather():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
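A minimal usage sketch of the overlap_allgather flag added above, assuming the existing Booster and HybridParallelPlugin APIs; the parallel sizes here are illustrative only. With zero_stage > 0, the flag makes the post-step parameter all-gather asynchronous and waits for it lazily inside forward.

    # Illustrative only; run under torchrun with a distributed environment.
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch()

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        zero_stage=1,
        precision="bf16",
        overlap_allgather=True,  # new flag introduced by this commit
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)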

View File

@@ -2,6 +2,7 @@ import enum
 import logging
 import os
 import warnings
+from contextlib import nullcontext
 from functools import partial
 from pathlib import Path
 from types import MethodType
@@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):

 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str) -> None:
+    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -72,12 +76,25 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.convert_fn = None
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+        self.overlap_allgather = overlap_allgather
+        if overlap_allgather:
+            self.op_hook = ZeroOpHook()
+            for p in module.parameters():
+                if p.requires_grad and type(p) is not ColoParameter:
+                    p.__class__ = ColoParameter
+                    p.__init__(p, requires_grad=True)

     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+        ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+        with ctx:
+            return super().forward(*args, **kwargs)
+
+    def _force_wait_all_gather(self):
+        for p in self.module.parameters():
+            wait_all_gather_handle(p)


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
@@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_unsharded_model(model, checkpoint, strict)
         model.update_master_params()
@@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         load_sub_module: bool = True,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        max_shard_size: int = 1024,
+        use_safetensors: bool = False,
+    ):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        model._force_wait_all_gather()
+        return super().save_sharded_model(
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+        )
+
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         peft_model = model.unwrap()
         assert isinstance(
             peft_model, PeftModel
@@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         reduce_bucket_size_in_m: int = 12,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
+        overlap_allgather: bool = False,
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
@@ -315,6 +356,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            overlap_allgather=overlap_allgather,
         )
         self.lora_enabled = False
         self.verbose = verbose
@@ -431,7 +473,9 @@ class LowLevelZeroPlugin(DPPluginBase):
             self.add_lora_params_to_optimizer(model, optimizer)

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.precision)
+            model = LowLevelZeroModel(
+                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+            )

         # TODO: Support Galore + ZeRO
         zero_stage = self.stage

View File

@@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         model = model.unwrap()

         if os.path.isfile(checkpoint):
@@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             This argument should be manually set to False since params on same device might be stored in different files.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()
@@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         model = model.unwrap()

         if self.dp_rank != 0:
@@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")

         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         strict = False
         model_before_wrapping = model
         model = model.unwrap()

View File

@@ -91,7 +91,11 @@ def _broadcast_object_list(
     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
     if my_rank == src:
-        if Version(torch.__version__) >= Version("1.13.0"):
+        if Version(torch.__version__) >= Version("2.3.0"):
+            tensor_list, size_list = zip(
+                *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
+            )
+        elif Version(torch.__version__) >= Version("1.13.0"):
             tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
         else:
             tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
@@ -276,7 +280,11 @@ def _send_recv_serialization_object(
     send_object_tensor = None
     send_object_size_tensor = None
     if object is not None and send_dst is not None:
-        if Version(torch.__version__) >= Version("1.13.0"):
+        if Version(torch.__version__) >= Version("2.3.0"):
+            send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
+                object, device=current_device, group=send_group
+            )
+        elif Version(torch.__version__) >= Version("1.13.0"):
             send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
         else:
             send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
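Both hunks above apply the same guard: the private c10d._object_to_tensor API gained a group argument in newer torch releases, so the call is dispatched on the installed version. A generic sketch of this version-gating pattern; the helper name serialize_obj is hypothetical and not part of this codebase:

    import torch
    from packaging.version import Version
    from torch.distributed import distributed_c10d as c10d


    def serialize_obj(obj, device, group=None):
        # Dispatch on torch version: the private _object_to_tensor API grew extra arguments over time.
        if Version(torch.__version__) >= Version("2.3.0"):
            return c10d._object_to_tensor(obj, device=device, group=group)
        if Version(torch.__version__) >= Version("1.13.0"):
            return c10d._object_to_tensor(obj, device=device)
        return c10d._object_to_tensor(obj)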

View File

@@ -1,3 +1,4 @@
+import math
 from typing import List, Optional, Tuple, Union

 import torch
@@ -513,7 +514,6 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             query_states = all_to_all_comm(query_states, sp_group)
@@ -698,9 +698,9 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         next_decoder_cache = None

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)

         for decoder_layer in self.layers:
             if output_hidden_states:

View File

@@ -473,7 +473,7 @@ class LayoutConverter(metaclass=SingletonMeta):
             for process_group in used_process_groups:
                 try:
                     dist.get_rank(process_group)
-                except RuntimeError as e:
+                except (ValueError, RuntimeError) as e:
                     # If the group is not registered, it means it has been deleted
                     if str(e) == (
                         f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"

View File

@@ -1,4 +1,3 @@
-from copy import deepcopy
 from typing import Dict, List

 from ..utils import merge_same_dim_mesh_list
@@ -23,10 +22,11 @@
     Otherwise, the element in shard_list means the data will be sharded in that dimension.
     """

+    _DIFFERENCE_DICT = None
+
     def __init__(self, shard_list):
         self.is_replica = len(shard_list) == 0
         self.shard_list = shard_list
-        self.build_difference_2d_dict()

     def __eq__(self, other):
         return str(self) == str(other)
@@ -39,24 +39,43 @@
             target += str(dim)
         return target

-    def _convert_str_to_shard_list(self, str_spec):
+    @property
+    def difference_dict(self):
         """
-        Convert str_spec into shard_list.
+        Returns the difference dict, and lazily initializes it when needed
+
+        Return:
+            difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
+                difference dict
+        """
+        if self._DIFFERENCE_DICT is None:
+            self._DIFFERENCE_DICT = self._build_difference_2d_dict()
+        return self._DIFFERENCE_DICT
+
+    def dim_diff(self, other):
+        """
+        The difference between two DimSpec.

         Argument:
-            str_spec(str): dim spec in str type.
+            other(DimSpec): the dim spec to compare with.
+
+        Return:
+            difference(int): the difference between two DimSpec.
+
+        Example:
+            dim_spec = DimSpec([0])
+            other_dim_spec = DimSpec([0, 1])
+            print(dim_spec.dim_diff(other_dim_spec))
+
+        Output:
+            5
         """
+        difference = self.difference_dict[(str(self), str(other))]
+        return difference

-        if str_spec == "R":
-            return []
-        if str_spec == "S0":
-            return [0]
-        if str_spec == "S1":
-            return [1]
-        if str_spec == "S01":
-            return [0, 1]
-
-    def build_difference_2d_dict(self):
+    @classmethod
+    def _build_difference_2d_dict(cls):
         """
         Build a difference mapping for 2D device mesh case. It will be used to
         compute the difference between DimSpec pairs.
@@ -67,9 +86,8 @@
         difference_dict = {}
         for source_spec in source_spec_list:
             for target_spec in target_spec_list:
-                spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
-                source_shard_list = self._convert_str_to_shard_list(source_spec)
-                target_shard_list = self._convert_str_to_shard_list(target_spec)
+                source_shard_list = cls._convert_str_to_shard_list(source_spec)
+                target_shard_list = cls._convert_str_to_shard_list(target_spec)

                 # source same as target
                 if source_shard_list == target_shard_list:
@@ -112,30 +130,27 @@
                 else:
                     difference = NAN
-                difference_dict[spec_pair] = difference
+                difference_dict[(source_spec, target_spec)] = difference

-        self.difference_dict = difference_dict
+        return difference_dict

-    def dim_diff(self, other):
+    @staticmethod
+    def _convert_str_to_shard_list(str_spec):
         """
-        The difference between two _DimSpec.
+        Convert str_spec into shard_list.

         Argument:
-            other(_DimSpec): the dim spec to compare with.
-
-        Return:
-            difference(int): the difference between two _DimSpec.
-
-        Example:
-            dim_spec = _DimSpec([0])
-            other_dim_spec = _DimSpec([0, 1])
-            print(dim_spec.difference(other_dim_spec))
-
-        Output:
-            5
+            str_spec(str): dim spec in str type.
         """
-        difference = self.difference_dict[(str(self), str(other))]
-        return difference
+
+        if str_spec == "R":
+            return []
+        if str_spec == "S0":
+            return [0]
+        if str_spec == "S1":
+            return [1]
+        if str_spec == "S01":
+            return [0, 1]


 class ShardingSpec:
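The refactor above replaces a difference table rebuilt in every __init__ with one that is built lazily on first access. A stripped-down sketch of the same lazy class-level cache pattern; PairTable and its contents are illustrative, not the real DimSpec data. Note that the diff assigns the cache through self, which shadows the class attribute per instance, while the sketch below caches directly on the class:

    class PairTable:
        _TABLE = None  # class-level cache shared by all instances

        def __init__(self, key):
            # No table construction here, unlike the old build_difference_2d_dict() call in __init__.
            self.key = key

        @property
        def table(self):
            # Build the lookup table once, on first access.
            if type(self)._TABLE is None:
                type(self)._TABLE = self._build_table()
            return type(self)._TABLE

        @classmethod
        def _build_table(cls):
            specs = ["R", "S0", "S1", "S01"]
            return {(a, b): (0 if a == b else 1) for a in specs for b in specs}

        def diff(self, other):
            return self.table[(self.key, other.key)]


    print(PairTable("R").diff(PairTable("S0")))  # -> 1, and the table is built exactly once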

View File

@@ -1,5 +1,4 @@
 import operator
-from copy import deepcopy
 from functools import reduce

 import torch
@@ -27,10 +26,11 @@
     Otherwise, the element in shard_list means the data will be sharded in that dimension.
     """

+    _DIFFERENCE_DICT = None
+
     def __init__(self, shard_list):
         self.is_replica = len(shard_list) == 0
         self.shard_list = shard_list
-        self.build_difference_2d_dict()

     def __eq__(self, other):
         return str(self) == str(other)
@@ -43,27 +43,46 @@
             target += str(dim)
         return target

-    def _convert_str_to_shard_list(self, str_spec):
+    @property
+    def difference_dict(self):
         """
-        Convert str_spec into shard_list.
+        Returns the difference dict, and lazily initializes it when needed
+
+        Return:
+            difference_dict(Dict[Tuple[int, int], Union[int, float, str]]):
+                difference dict
+        """
+        if self._DIFFERENCE_DICT is None:
+            self._DIFFERENCE_DICT = self._build_difference_2d_dict()
+        return self._DIFFERENCE_DICT
+
+    def difference(self, other):
+        """
+        The difference between two _DimSpec.

         Argument:
-            str_spec(str): dim spec in str type.
+            other(_DimSpec): the dim spec to compare with.
+
+        Return:
+            difference(int): the difference between two _DimSpec.
+
+        Example:
+            dim_spec = _DimSpec([0])
+            other_dim_spec = _DimSpec([0, 1])
+            print(dim_spec.difference(other_dim_spec))
+
+        Output:
+            5
         """
+        difference = self.difference_dict[(str(self), str(other))]
+        return difference

-        if str_spec == "R":
-            return []
-        if str_spec == "S0":
-            return [0]
-        if str_spec == "S1":
-            return [1]
-        if str_spec == "S01":
-            return [0, 1]
-
-    def build_difference_2d_dict(self):
+    @classmethod
+    def _build_difference_2d_dict(cls):
         """
         Build a difference mapping for 2D device mesh case. It will be used to
-        compute the difference between DimSpec pairs.
+        compute the difference between _DimSpec pairs.
         """

         source_spec_list = ["R", "S0", "S1", "S01"]
@@ -71,9 +90,8 @@
         difference_dict = {}
         for source_spec in source_spec_list:
             for target_spec in target_spec_list:
-                spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
-                source_shard_list = self._convert_str_to_shard_list(source_spec)
-                target_shard_list = self._convert_str_to_shard_list(target_spec)
+                source_shard_list = cls._convert_str_to_shard_list(source_spec)
+                target_shard_list = cls._convert_str_to_shard_list(target_spec)

                 # source same as target
                 if source_shard_list == target_shard_list:
@@ -116,30 +134,27 @@
                 else:
                     difference = NAN
-                difference_dict[spec_pair] = difference
+                difference_dict[(source_spec, target_spec)] = difference

-        self.difference_dict = difference_dict
+        return difference_dict

-    def difference(self, other):
+    @staticmethod
+    def _convert_str_to_shard_list(str_spec):
         """
-        The difference between two _DimSpec.
+        Convert str_spec into shard_list.

         Argument:
-            other(_DimSpec): the dim spec to compare with.
-
-        Return:
-            difference(int): the difference between two _DimSpec.
-
-        Example:
-            dim_spec = _DimSpec([0])
-            other_dim_spec = _DimSpec([0, 1])
-            print(dim_spec.difference(other_dim_spec))
-
-        Output:
-            5
+            str_spec(str): dim spec in str type.
         """
-        difference = self.difference_dict[(str(self), str(other))]
-        return difference
+
+        if str_spec == "R":
+            return []
+        if str_spec == "S0":
+            return [0]
+        if str_spec == "S1":
+            return [1]
+        if str_spec == "S01":
+            return [0, 1]


 class ShardingSpecException(Exception):

View File

@@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger

 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
+from .zero_hook import set_all_gather_handle, wait_all_gather_handle


 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        overlap_allgather: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # communication params
         self._overlap_communication = overlap_communication
+        self._overlap_allgather = overlap_allgather
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
@@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # record the padding size of each param
         self._padding_map = dict()
+        # padded working param is all-gather buffer and it shares the same memory with working param
+        self._working_param_to_padded_working_param = dict()

         # mapping working param and master param
         self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                    # reset working params' ptr when no master weights
-                    if self._master_weights == False:
-                        param.data = padding_param[: param.numel()].view(param.shape)
+                    # # reset working params' ptr when no master weights
+                    # if self._master_weights == False:
+                    param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
+                self._working_param_to_padded_working_param[param] = padding_param

                 splited_params = padding_param.split(
                     padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

                 # use fp32 when master_weights is True
                 if self._master_weights is True:
-                    splited_param_current_rank = splited_params.detach().float().to(device)
+                    splited_param_current_rank = splited_params.detach().clone().float().to(device)
                 else:
                     splited_param_current_rank = splited_params
@@ -549,22 +555,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param = real_working_params[group_id][idx]
                 param_to_gather = master_param.to(device).to(self._dtype)
                 pg = self.param_to_pg[working_param]
-                if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
-                    buffer_tensor = torch.empty_like(
-                        torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
-                    )
-                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
-                    working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
-                    continue
-                try:
-                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
-                except RuntimeError:
-                    self.pg_to_tensor_bucket[pg].all_gather(pg)
-                    self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                padded_working_param = self._working_param_to_padded_working_param[working_param]
+                if self._overlap_allgather:
+                    handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                    set_all_gather_handle(working_param, handle)
+                else:
+                    if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                        dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                        continue
+                    try:
+                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                    except RuntimeError:
+                        self.pg_to_tensor_bucket[pg].all_gather(pg)
+                        self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-        for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
-            if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+        if not self._overlap_allgather:
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    tensor_bucket.all_gather(pg)

     def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
@@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         grad_store = self.pid_to_grad_store[param_id]
         return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
+
+    def _force_wait_all_gather(self):
+        for param in self._working_param_to_padded_working_param.keys():
+            wait_all_gather_handle(param)
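The step() rewrite above is the core of overlap_allgather: the all-gather of updated parameters is launched with async_op=True into the padded working buffer, the work handle is stashed on the parameter via set_all_gather_handle, and it is waited on only when the parameter is next touched (ZeroOpHook.pre_forward) or forced via _force_wait_all_gather before checkpointing. A self-contained sketch of that launch-now, wait-later pattern in plain torch.distributed; tensor names and sizes are illustrative, and it assumes an initialized process group whose backend supports all_gather_into_tensor:

    import torch
    import torch.distributed as dist


    def launch_gather(padded_param: torch.Tensor, shard: torch.Tensor, group=None):
        # Kick off the collective without blocking and remember the work handle on the tensor,
        # mirroring the _all_gather_handle attribute used in this commit.
        handle = dist.all_gather_into_tensor(padded_param, shard, group=group, async_op=True)
        padded_param._gather_handle = handle


    def wait_gather(padded_param: torch.Tensor):
        # Pay the synchronization cost only when the gathered values are actually needed.
        handle = getattr(padded_param, "_gather_handle", None)
        if handle is not None:
            handle.wait()
            del padded_param._gather_handle


    # Intended call order (inside an initialized process group):
    #   launch_gather(padded, shard)   # at the end of optimizer.step()
    #   ... other compute overlaps here ...
    #   wait_gather(padded)            # the first consumer of the parameter blocks here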

View File

@@ -0,0 +1,33 @@
+from typing import List
+
+from torch._tensor import Tensor
+
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+_ALL_GATHER_HANDLE = "_all_gather_handle"
+
+
+def wait_all_gather_handle(p):
+    if hasattr(p, _ALL_GATHER_HANDLE):
+        handle = getattr(p, _ALL_GATHER_HANDLE)
+        handle.wait()
+        delattr(p, _ALL_GATHER_HANDLE)
+
+
+def set_all_gather_handle(p, handle):
+    setattr(p, _ALL_GATHER_HANDLE, handle)
+
+
+class ZeroOpHook(ColoParamOpHook):
+    def pre_forward(self, params: List[Tensor]) -> None:
+        for p in params:
+            wait_all_gather_handle(p)
+
+    def post_forward(self, params: List[Tensor]) -> None:
+        pass
+
+    def pre_backward(self, params: List[Tensor]) -> None:
+        pass
+
+    def post_backward(self, params: List[Tensor]) -> None:
+        pass

View File

@@ -98,6 +98,7 @@ def main():
     parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--overlap_allgather", action="store_true")
     args = parser.parse_args()

     colossalai.launch_from_torch()
@@ -199,9 +200,9 @@ def main():
             enable_flash_attention=args.xformers,
             microbatch_size=args.mbs,
             precision="bf16",
-            dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":

View File

@@ -113,13 +113,13 @@ class PerformanceEvaluator:
         self.disable = self.ignore_steps > 0 and step < self.ignore_steps
         if self.disable:
             return
-        get_accelerator().synchronize()
+        # get_accelerator().synchronize()
         self.timer.start()

     def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         if self.disable:
             return
-        get_accelerator().synchronize()
+        # get_accelerator().synchronize()
         self.timer.end()

         batch_size, seq_len = input_ids.shape

View File

@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=2.1.0,<2.3.0
+torch>=2.1.0,<=2.3.0
 safetensors
 einops
 pydantic

View File

@@ -135,51 +135,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 1,
-            "pp_size": 2,
-            "num_microbatches": 2,
-            "enable_all_optimization": True,
-            "use_lazy_init": True,
-            "zero_stage": 1,
-            "precision": "fp16",
-            "initial_scale": 1,
-        },
-    ],
-)
-def run_qwen2_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
-    clear_layout_converter()
-    Randomizer.reset_index()
-    torch.cuda.empty_cache()
-
-
-@parameterize(
-    "test_config",
-    [
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": False,
-            "use_lazy_init": False,
-            "precision": "fp32",
-            "initial_scale": 1,
-        },
-        {
-            "tp_size": 2,
-            "pp_size": 2,
-            "num_microbatches": 4,
-            "enable_all_optimization": False,
-            "use_lazy_init": False,
-            "precision": "fp16",
-            "zero_stage": 1,
-            "initial_scale": 1,
-        },
         { # Ulysess + Flash attention
             "tp_size": 1,
             "pp_size": 2,
@@ -242,6 +197,54 @@
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "zero_stage": 1,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_qwen2_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -259,7 +262,11 @@ def run_qwen2_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e

     clear_layout_converter()
     Randomizer.reset_index()

View File

@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
     zero1_optimizer.step()
     zero2_optimizer.step()

+    zero1_optimizer._force_wait_all_gather()
+    zero2_optimizer._force_wait_all_gather()
+
     # check updated param
     for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
+        assert not hasattr(z1p, "_all_gather_handle")
         assert torch.equal(z1p.data, z2p.data)

View File

@@ -177,6 +177,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
     # torch ddp step
     torch_optimizer.step()

+    zero_optimizer._force_wait_all_gather()
+
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         loose_close(p, z1p, dtype=dtype)

View File

@@ -1 +1 @@
-0.4.0
+0.4.1