mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
fix typo examples and docs (#3932)
This commit is contained in:
@@ -162,7 +162,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
# shard it w.r.t tp pattern
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
# keep the shape of the output from c_fc
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
@@ -173,9 +173,9 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
split_param_col_tp1d(param, pg) # column slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.visited = True
|
||||
@@ -237,7 +237,7 @@ def main():
|
||||
if args.tp_degree > 1:
|
||||
tensor_parallelize(model, tp_pg)
|
||||
|
||||
# asign running configurations
|
||||
# assign running configurations
|
||||
if args.distplan == "CAI_ZeRO1":
|
||||
zero_stage = 1
|
||||
elif args.distplan == "CAI_ZeRO2":
|
||||
|
Reference in New Issue
Block a user