diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 7dce6e6da..443c4094c 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -6,6 +6,7 @@
 from torch import Tensor
 from torch.optim import Optimizer
 
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils.device import autocast
 
 from .mixed_precision_base import MixedPrecision
@@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)
 
     def forward(self, *args, **kwargs):
-        with torch.cuda.amp.autocast():
+        with autocast():
             return self.module(*args, **kwargs)
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index f9716ab97..bbc36ceab 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.utils.device import get_current_device
 
 from .pp_plugin_base import PipelinePluginBase
@@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.cuda()
+        module = module.to(get_current_device())
 
         # setting input type cast when using mixed precision
         self.convert_fn = None
@@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
@@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if dp_size > 1:
                 # compute norm in dp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
@@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return self.pp_size > 1
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def supported_precisions(self) -> List[str]:
         return ["fp16", "bf16", "fp32"]
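Note on the recurring grad-norm edit above: torch.cuda.FloatTensor([...]) hard-codes a CUDA allocation, while torch.tensor(..., device=get_current_device(), dtype=torch.float32) places the scalar on whichever accelerator is active (CUDA or NPU). A minimal sketch of the pattern; reduce_grad_norm is an illustrative helper, not code from this patch:

    import torch
    import torch.distributed as dist

    def reduce_grad_norm(total_norm: float, device: torch.device, group=None) -> float:
        # Device-agnostic replacement for torch.cuda.FloatTensor([total_norm]):
        # allocate the scalar on the active accelerator, then all-reduce it.
        norm_tensor = torch.tensor([float(total_norm)], device=device, dtype=torch.float32)
        if dist.is_initialized():
            dist.all_reduce(norm_tensor, op=dist.ReduceOp.MAX, group=group)
        return norm_tensor.item()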
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 72f199203..3949590e8 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -38,7 +38,7 @@ class DeviceMesh:
         device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
     """
 
-    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}
 
     def __init__(
         self,
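_DIST_BACKEND maps a device type to the torch.distributed backend the mesh uses when it builds process groups, so the new "npu": "hccl" entry is what lets a DeviceMesh run on Ascend hardware. A hedged sketch of how such a table is typically consumed; init_process_group_for is illustrative and not part of DeviceMesh:

    import torch.distributed as dist

    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

    def init_process_group_for(device_type: str, rank: int, world_size: int) -> None:
        # nccl for CUDA GPUs, hccl for Ascend NPUs, gloo as the CPU fallback.
        # Expects MASTER_ADDR/MASTER_PORT in the environment (default env:// init).
        backend = _DIST_BACKEND.get(device_type, "gloo")
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)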
diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py
index ced5cc3e6..0a8d09be2 100644
--- a/colossalai/legacy/amp/torch_amp/torch_amp.py
+++ b/colossalai/legacy/amp/torch_amp/torch_amp.py
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import torch.cuda.amp as torch_amp
+from colossalai.utils.device import autocast
+
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
@@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
         super().__init__()
         self.model = model
 
-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
@@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
         super().__init__()
         self.loss = loss
 
-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py
index 387e1c54e..9a8051ae9 100644
--- a/colossalai/legacy/utils/activation_checkpoint.py
+++ b/colossalai/legacy/utils/activation_checkpoint.py
@@ -7,7 +7,7 @@
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
 
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils import get_current_device
+from colossalai.utils.device import autocast, get_current_device
 
 
 def copy_to_device(obj, device):
@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
                 inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast():
+            with torch.enable_grad(), autocast():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():
@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
         # rerun forward, the inner_pack will store all the activations in storage
         if has_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+            with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
                 inner_pack, inner_unpack
             ):
                 _unused = function(*args)
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 7421f84bf..4b6343adc 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 from torch import nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup, get_world_size
+from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
 
 
 class SeqParallelUtils:
@@ -104,14 +105,14 @@ class Randomizer:
     def __init__(self, seed: int):
         self.seed = seed
 
-        # Handle CUDA rng state
+        # Handle device rng state
         # 1. get the current rng state
         # 2. set the seed and store the rng state
         # 3. recover the original rng state
-        cuda_original_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.manual_seed(seed)
-        self.cuda_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(cuda_original_rng_state)
+        device_original_rng_state = get_rng_state()
+        manual_seed(seed)
+        self.device_rng_state = get_rng_state()
+        set_rng_state(device_original_rng_state)
 
         # to the same for cpu rng state
         cpu_original_rng_state = torch.get_rng_state()
@@ -119,11 +120,11 @@ class Randomizer:
         self.cpu_rng_state = torch.get_rng_state()
         torch.set_rng_state(cpu_original_rng_state)
 
-    def _set_cuda_rng_state(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
+    def _set_device_rng_state(self, rng_state):
+        set_rng_state(rng_state)
 
-    def _get_cuda_rng_state(self):
-        current_state = torch.cuda.get_rng_state()
+    def _get_device_rng_state(self):
+        current_state = get_rng_state()
         return current_state
 
     def _set_cpu_rng_state(self, rng_state):
@@ -144,16 +145,16 @@ class Randomizer:
         >>> input = super().forward(input)
         """
         try:
-            current_cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(self.cuda_rng_state)
+            current_device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(self.device_rng_state)
 
             if enable_cpu:
                 current_cpu_rng_state = self._get_cpu_rng_state()
                 self._set_cpu_rng_state(self.cpu_rng_state)
             yield
         finally:
-            self.cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(current_cuda_rng_state)
+            self.device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(current_device_rng_state)
 
             if enable_cpu:
                 self.cpu_rng_state = self._get_cpu_rng_state()
@@ -208,7 +209,7 @@ class Randomizer:
         index = Randomizer.index()
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -230,7 +231,7 @@ class Randomizer:
 
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
 
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
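The Randomizer edits above lean on device-agnostic RNG helpers exported by colossalai.utils.device. The real module routes these through its internal _dispatch_device_func; the standalone sketch below only illustrates the assumed behaviour and is not the actual implementation:

    import torch

    def _device_module():
        # Prefer CUDA; fall back to Ascend NPU (torch.npu exists once torch_npu is imported).
        if torch.cuda.is_available():
            return torch.cuda
        if hasattr(torch, "npu") and torch.npu.is_available():
            return torch.npu
        raise RuntimeError("No accelerator available")

    def get_rng_state() -> torch.Tensor:
        return _device_module().get_rng_state()

    def set_rng_state(state: torch.Tensor) -> None:
        _device_module().set_rng_state(state)

    def manual_seed(seed: int) -> None:
        _device_module().manual_seed(seed)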
diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py
index 839e7aab3..7cd24b0ad 100644
--- a/colossalai/testing/utils.py
+++ b/colossalai/testing/utils.py
@@ -9,6 +9,7 @@
 from typing import Any, Callable, List
 
 import torch
 import torch.multiprocessing as mp
 from packaging import version
+from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
 
 
 def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -198,7 +199,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
     def _wrap_func(f):
         def _execute_by_gpu_num(*args, **kwargs):
-            num_avail_gpu = torch.cuda.device_count()
+            num_avail_gpu = device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)
@@ -262,11 +263,11 @@ def clear_cache_before_run():
     def _wrap_func(f):
         def _clear_cache(*args, **kwargs):
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
-            torch.cuda.reset_max_memory_allocated()
-            torch.cuda.reset_max_memory_cached()
-            torch.cuda.synchronize()
+            empty_cache()
+            reset_peak_memory_stats()
+            reset_max_memory_allocated()
+            reset_max_memory_cached()
+            synchronize()
             gc.collect()
             f(*args, **kwargs)
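With the testing helpers now routed through colossalai.utils.device, the same decorators behave identically on CUDA and NPU test machines. A small usage sketch, assuming both decorators are re-exported from colossalai.testing; the test body itself is made up:

    import torch
    from colossalai.testing import clear_cache_before_run, skip_if_not_enough_gpus
    from colossalai.utils.device import get_current_device

    @clear_cache_before_run()
    @skip_if_not_enough_gpus(min_gpus=1)
    def test_allocation_is_device_agnostic():
        # Caches and peak-memory stats are reset before the test runs,
        # on whichever accelerator backend is active.
        x = torch.ones(1024, 1024, device=get_current_device())
        assert x.sum().item() == 1024 * 1024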
diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py
index e1bd20d59..c70dbdaa5 100644
--- a/colossalai/utils/device.py
+++ b/colossalai/utils/device.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable
 
 import torch
 import torch.distributed as dist
@@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
     return _dispatch_device_func("reset_max_memory_allocated", device)
 
 
+def reset_max_memory_cached(device=None) -> None:
+    return _dispatch_device_func("reset_max_memory_cached", device)
+
+
 def memory_reserved(device=None) -> int:
     return _dispatch_device_func("memory_reserved", device)
@@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
     return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
 
 
 def reset_peak_memory_stats(device=None) -> None:
     return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+
+
+def autocast() -> Callable:
+    if torch.cuda.is_available():
+        return torch.cuda.amp.autocast()
+    elif IS_NPU_AVAILABLE:
+        return torch.npu.amp.autocast()
+    else:
+        raise RuntimeError("No device available")
diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py
index 1b64363bb..d7a79a022 100644
--- a/examples/language/llama2/benchmark.py
+++ b/examples/language/llama2/benchmark.py
@@ -131,7 +131,7 @@ def main():
             tp_size=args.tp,
             pp_size=args.pp,
             zero_stage=args.zero,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             precision="bf16",
         )
@@ -141,7 +141,7 @@ def main():
             pp_size=args.pp,
             zero_stage=args.zero,
             cpu_offload=True,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             initial_scale=2**8,
             precision="bf16",
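The new autocast() helper lets call sites stay backend-neutral: it returns torch.cuda.amp.autocast() on GPUs and torch.npu.amp.autocast() on Ascend NPUs. A short usage sketch; the model and data below are placeholders:

    import torch
    import torch.nn as nn
    from colossalai.utils.device import autocast, get_current_device

    model = nn.Linear(16, 16).to(get_current_device())
    data = torch.randn(4, 16, device=get_current_device())

    with autocast():
        # Matmuls run in reduced precision where the active backend supports it.
        out = model(data)

The benchmark change follows the same idea: enable_fused_normalization=torch.cuda.is_available() keeps the fused (CUDA-kernel) LayerNorm on GPU runs and falls back to the plain implementation elsewhere.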