mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fix\ fix fail case test_shard_llama
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
@@ -544,7 +545,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
ctx = optimizer.no_sync()
|
||||
except AttributeError:
|
||||
ctx = model_chunk.no_sync()
|
||||
|
||||
with ctx:
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
|
Reference in New Issue
Block a user