[fix\ fix fail case test_shard_llama

This commit is contained in:
duanjunwen
2024-10-25 02:28:55 +00:00
parent 2eca112c90
commit d0ec221b38
5 changed files with 10 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_map
@@ -544,7 +545,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
ctx = optimizer.no_sync()
except AttributeError:
ctx = model_chunk.no_sync()
with ctx:
optimizer.backward_by_grad(
tensor=output_obj_,