mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[fix\ fix fail case test_shard_llama
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.utils._pytree import tree_flatten, tree_map
|
||||
|
||||
@@ -544,7 +545,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
ctx = optimizer.no_sync()
|
||||
except AttributeError:
|
||||
ctx = model_chunk.no_sync()
|
||||
|
||||
with ctx:
|
||||
optimizer.backward_by_grad(
|
||||
tensor=output_obj_,
|
||||
|
@@ -228,5 +228,4 @@ class PipelineStageManager:
|
||||
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
# print(f"layers_per_stage {layers_per_stage}")
|
||||
return layers_per_stage
|
||||
|
Reference in New Issue
Block a user