Merge pull request #6071 from wangbluo/ring_attention

[Ring Attention] fix the 2D ring attn when using multiple machines
Author: Wang Binluo
Date: 2024-10-15 15:17:21 +08:00
Committed by: GitHub
6 changed files with 41 additions and 31 deletions

@@ -1177,7 +1177,10 @@ class HybridParallelPlugin(PipelinePluginBase):
     gradient_checkpoint_config=gradient_checkpoint_config,
     fp8_communication=fp8_communication,
     inner_ring_size=inner_ring_size,
+    pg_mesh=self.pg_mesh,
+    sp_axis=self.sp_axis,
 )
 self.amp_config = dict(
     initial_scale=initial_scale,
     growth_factor=growth_factor,
)
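
For context: the added pg_mesh / sp_axis arguments hand the shard config (and, downstream, the ring-attention layer) a globally consistent view of the device mesh, which it presumably needs to build its double-ring process groups correctly once ranks span more than one machine. The sketch below illustrates that group-construction pattern in plain PyTorch; the helper name, arguments, and ring layout are assumptions for illustration, not ColossalAI's actual implementation.

# A minimal sketch of the idea behind the fix, NOT ColossalAI's actual code:
# 2D (double) ring attention splits one sequence-parallel ring into small
# "inner" rings (ideally within one machine, on the fast interconnect) and
# "inter" rings across machines. Building these groups needs a globally
# consistent rank layout -- hence passing pg_mesh / sp_axis down -- because
# dist.new_group() is collective and every rank must create every group.
import torch.distributed as dist


def build_double_ring_groups(sp_ranks, inner_ring_size):
    """Derive inner (intra-machine) and inter (cross-machine) ring groups.

    `sp_ranks` is the global rank list of one sequence-parallel group;
    `inner_ring_size` is how many ranks form each inner ring.
    Hypothetical helper for illustration only.
    """
    sp_size = len(sp_ranks)
    assert sp_size % inner_ring_size == 0, "inner_ring_size must divide sp_size"
    num_rings = sp_size // inner_ring_size

    my_rank = dist.get_rank()
    inner_group = inter_group = None

    # Inner rings: consecutive ranks, e.g. [0..7] on machine 0, [8..15] on 1.
    for i in range(num_rings):
        ranks = sp_ranks[i * inner_ring_size : (i + 1) * inner_ring_size]
        group = dist.new_group(ranks)  # collective: all ranks must call this
        if my_rank in ranks:
            inner_group = group

    # Inter rings: strided ranks at the same inner position across machines,
    # e.g. [0, 8], [1, 9], ... so each rank sits on exactly one of each ring.
    for j in range(inner_ring_size):
        ranks = sp_ranks[j::inner_ring_size]
        group = dist.new_group(ranks)
        if my_rank in ranks:
            inter_group = group

    return inner_group, inter_group

If each rank instead derives such groups from purely local information, rank lists can disagree across machines and the collective group creation deadlocks or produces wrong rings, which is consistent with the multi-machine failure this PR addresses.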