[fix] fix missing reduce variable

This commit is contained in:
hxwang 2024-05-29 02:09:14 +00:00
parent b5ae587d50
commit fee35678e5

View File

@@ -422,8 +422,8 @@ class GeminiDDP(ModelWrapper):
grad_chunk.add_tensor_to_chunk_slice(p, grad)
with torch.cuda.stream(async_reduce_stream):
-chunk_manager.reduce_chunk(grad_chunk)
+reduced = chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)