[autoparallel] accelerate gpt2 training (#2495)

This commit is contained in:
YuliangLiu0306
2023-01-29 11:13:15 +08:00
committed by GitHub
parent a360b9bc44
commit aa0f6686f9
5 changed files with 21 additions and 17 deletions

View File

@@ -98,7 +98,7 @@ class DeviceMesh:
return DeviceMesh(self.physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self.init_process_group,
need_flatten=False)