From 2ac24040ebe3d6c140db1442fed2d33ff5ceb2ec Mon Sep 17 00:00:00 2001
From: digger yu
Date: Tue, 4 Jul 2023 17:53:39 +0800
Subject: [PATCH] fix some typo colossalai/shardformer (#4160)

---
 colossalai/shardformer/README.md              |  2 +-
 colossalai/shardformer/layer/loss.py          | 12 ++++++------
 colossalai/shardformer/policies/basepolicy.py |  2 +-
 colossalai/shardformer/shard/sharder.py       |  2 +-
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index fca401562..6ae32e4fb 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -252,7 +252,7 @@ class ModelSharder:
 
     def shard(self) -> None:
         """
-        Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
+        Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
         """
         ...
 
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 38a5395a0..7e3f6926b 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -48,13 +48,13 @@ class DistCrossEntropy(Function):
 
         # [down, up) => false, other device and -100 => true
         delta = (global_vocab_size + world_size - 1) // world_size
-        down_shreshold = rank * delta
-        up_shreshold = down_shreshold + delta
-        mask = (target < down_shreshold) | (target >= up_shreshold)
-        masked_target = target.clone() - down_shreshold
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
         masked_target[mask] = 0
 
-        # reshape the logist and target
+        # reshape the logits and target
         # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
         # reshape the labels to [bath_size * seq_len]
         logits_2d = vocab_logits.view(-1, partition_vocab_size)
@@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
         loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
         loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
 
-        # caculate the softmax
+        # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
 
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 85e6d509c..2d347542f 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -66,7 +66,7 @@ class Policy(ABC):
     like BertPolicy for Bert Model or OPTPolicy for OPT model.
 
     Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
-    built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
     If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
     """
 
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 2867a0a4f..201e0a08c 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -73,7 +73,7 @@ class ModelSharder(object):
             layer (torch.nn.Module): The object of layer to shard
             origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
             attr_replacement (Dict): The attribute dict to modify
-            param_replacement (List[Callable]): The function list to get parameter shard information in polic
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
             sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
         """
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
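
Below is a minimal, self-contained sketch of the vocab-parallel target masking that the loss.py hunk above renames (down_shreshold -> down_threshold). It is not the Colossal-AI implementation: the helper name mask_targets_for_rank and the toy vocab/world-size values are assumptions chosen purely for illustration.

import torch

def mask_targets_for_rank(target: torch.Tensor, global_vocab_size: int,
                          world_size: int, rank: int):
    """Map global token ids onto this rank's local vocabulary partition."""
    # each rank owns a contiguous slice of the vocabulary of size `delta`
    delta = (global_vocab_size + world_size - 1) // world_size
    down_threshold = rank * delta
    up_threshold = down_threshold + delta

    # True where the id belongs to another rank (or is an ignore label such as -100)
    mask = (target < down_threshold) | (target >= up_threshold)

    # shift into local index space; masked entries are parked at index 0
    masked_target = target.clone() - down_threshold
    masked_target[mask] = 0
    return masked_target, mask

if __name__ == "__main__":
    # vocab of 10 split over 2 ranks -> rank 1 owns ids [5, 10)
    target = torch.tensor([1, 5, 9, -100])
    local_target, mask = mask_targets_for_rank(target, global_vocab_size=10,
                                               world_size=2, rank=1)
    print(local_target)  # tensor([0, 0, 4, 0])
    print(mask)          # tensor([ True, False, False,  True])

Parking out-of-range ids at index 0 keeps the later per-token logit gather valid on every rank; in the Megatron-style scheme this code follows, the mask is then used to zero those gathered logits before they are reduced across ranks.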