[npu] add npu support for hybrid plugin and llama (#5090)
* llama 3d
* update
* fix autocast
@@ -29,6 +29,7 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.utils.device import get_current_device
 
 from .pp_plugin_base import PipelinePluginBase
 
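The hunks that follow replace hard-coded CUDA calls (module.cuda(), torch.cuda.FloatTensor) with this get_current_device() helper so the same code path also runs on Ascend NPUs. As a rough, hypothetical sketch of what such a helper has to do (this is not ColossalAI's actual implementation in colossalai.utils.device, and it assumes that installing the torch_npu extension exposes a torch.npu namespace mirroring torch.cuda):

```python
import torch


def current_accelerator() -> torch.device:
    """Hypothetical device resolver, for illustration only."""
    npu = getattr(torch, "npu", None)  # present only when torch_npu is installed (assumption)
    if npu is not None and npu.is_available():
        return torch.device(f"npu:{npu.current_device()}")
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")
```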
@@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.cuda()
+        module = module.to(get_current_device())
 
         # setting input type cast when using mixed precision
         self.convert_fn = None
@@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:
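The optimizer hunks below repeat this same substitution: torch.cuda.FloatTensor([...]) is a legacy constructor that always allocates on a CUDA device and therefore fails on an NPU-only machine, while torch.tensor(..., device=get_current_device(), dtype=torch.float32) places the scalar on whichever accelerator the process is using, so the following dist.all_reduce works on either backend. A minimal device-agnostic sketch of that pattern, reusing the hypothetical current_accelerator() helper from the previous sketch (the process-group handle is an assumption):

```python
import torch
import torch.distributed as dist


def all_reduce_max_norm(local_norm: float, group=None) -> float:
    # Allocate the scalar on the current accelerator instead of hard-coding CUDA.
    norm_tensor = torch.tensor([local_norm], device=current_accelerator(), dtype=torch.float32)
    if dist.is_available() and dist.is_initialized():
        # Take the maximum of the per-rank norms across the group (e.g. the TP group).
        dist.all_reduce(norm_tensor, op=dist.ReduceOp.MAX, group=group)
    return norm_tensor.item()
```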
@@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
@@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
 
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
 
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
 
                 total_norm_exponentiated += grad_norm_exponentiated
 
-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if dp_size > 1:
                 # compute norm in dp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
@@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return self.pp_size > 1
 
     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]
 
     def supported_precisions(self) -> List[str]:
         return ["fp16", "bf16", "fp32"]
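With "npu" added to supported_devices(), any code that consults this list (for example a device compatibility check before training) now accepts Ascend hosts as well as CUDA ones. A toy stand-in illustrating that idea (not ColossalAI code; class and method names besides supported_devices() are invented for this sketch):

```python
from typing import List


class DeviceAwarePlugin:
    """Toy example only; it mirrors the supported_devices() contract, not the real plugin."""

    def supported_devices(self) -> List[str]:
        # After this commit the real HybridParallelPlugin reports both accelerators.
        return ["cuda", "npu"]

    def check_device_type(self, device_type: str) -> None:
        if device_type not in self.supported_devices():
            raise ValueError(f"device type '{device_type}' is not supported by this plugin")


plugin = DeviceAwarePlugin()
plugin.check_device_type("cuda")  # accepted
plugin.check_device_type("npu")   # accepted thanks to the change above
# plugin.check_device_type("mps")  # would raise ValueError
```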