fix typo examples and docs (#3932)

This commit is contained in:
digger yu
2023-06-08 16:09:32 +08:00
committed by GitHub
parent 407aa48461
commit 33eef714db
8 changed files with 17 additions and 17 deletions

View File

@@ -162,7 +162,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# shard it w.r.t tp pattern
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # colmn slice
split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
@@ -173,9 +173,9 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # colmn slice
split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # colmn slice
split_param_col_tp1d(param, pg) # column slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
@@ -237,7 +237,7 @@ def main():
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
# asign running configurations
# assign running configurations
if args.distplan == "CAI_ZeRO1":
zero_stage = 1
elif args.distplan == "CAI_ZeRO2":