mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[fix] fix handle name; rm useless comments;
@@ -107,7 +107,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.local_send_backward_buffer = []

         # wait pp buffer
-        self.send_handles = []
+        self.wait_handles = []

     def assert_buffer_empty(self):
         # assert buffer is empty at end
@@ -129,7 +129,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         assert len(self.recv_backward_buffer[1]) == 0
         assert len(self.local_send_forward_buffer) == 0
         assert len(self.local_send_backward_buffer) == 0
-        # assert len(self.send_handles) == 0

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -891,7 +890,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
-                self.send_handles.append(wait_handle)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
                 self.schedule_f(
                     scheduled_node=scheduled_node,
@@ -915,7 +914,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        for h in self.send_handles:
+        for h in self.wait_handles:
             for hh in h:
                 hh.wait()
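
The rename from send_handles to wait_handles reflects what the buffer now holds: the handle from every communication launch, not just sends, drained at the end of the run. Each entry is itself a list of handles, which is why the drain loop is nested. Below is a minimal sketch of that shape, assuming the communication functions wrap torch.distributed.batch_isend_irecv (which returns one Work handle per P2POp) and an already-initialized process group; run_communication and drain are illustrative names, not the scheduler's API.

import torch
import torch.distributed as dist

def run_communication(tensor: torch.Tensor, peer: int) -> list:
    # batch_isend_irecv returns one Work object per P2POp, so even a
    # single send yields a list -- hence wait_handles is a list of lists.
    ops = [dist.P2POp(dist.isend, tensor, peer)]
    return dist.batch_isend_irecv(ops)

def drain(wait_handles: list) -> None:
    # Mirrors the loop in the hunk above: one entry per communication
    # launch, one Work handle per P2POp inside it.
    for h in wait_handles:
        for hh in h:
            hh.wait()
    wait_handles.clear()

# Inside the schedule loop (sketch):
#     wait_handles.append(run_communication(output_tensor, next_rank))
# At the end of the run:
#     drain(wait_handles)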
@@ -1,7 +1,5 @@
 import queue

 from colossalai.pipeline.stage_manager import PipelineStageManager


 class WeightGradStore:
@@ -32,52 +30,3 @@ class WeightGradStore:
                 weight.grad = grad_weight
             else:
                 raise Exception("Pop empty queue.")

     @classmethod
     def clear(cls, stage_manager: PipelineStageManager, chunk=0):
-        # print(f"stage {stage_manager.stage} len_chunk_0 {cls.weight_grad_queue[0].qsize()} len_chunk_1 {cls.weight_grad_queue[1].qsize()}")
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        #     stored_grads = cls.weight_grad_queue[chunk].get()
-        #     for total_input, grad_output, weight, func in stored_grads:
-        #         if weight.grad is not None:
-        #             func(total_input, grad_output, weight.grad)
-        #         # for first bwd; weight.grad is None, assign grad_weight to weight.grad
-        #         else:
-        #             grad_weight = func(total_input, grad_output)
-        #             weight.grad = grad_weight
-
-        # weight_grad_tasks = []
-        # while cls.weight_grad_queue[chunk].qsize() > 0:
-        #     stored_grads = cls.weight_grad_queue[chunk].get()
-        #     if len(weight_grad_tasks) == 0:
-        #         for _ in stored_grads:
-        #             weight_grad_tasks.append([])
-        #     else:
-        #         assert len(weight_grad_tasks) == len(stored_grads)
-        #     for i, task in enumerate(stored_grads):
-        #         weight_grad_tasks[i].append(task)
-
-        # if stage_manager.is_last_stage(ignore_chunk=True) and chunk == 1:
-        #     assert len(weight_grad_tasks) > 0
-        #     output_layer_grads = weight_grad_tasks[0]
-        #     for j in range(len(output_layer_grads)):
-        #         total_input, grad_output, weight, func = output_layer_grads[j]
-        #         if output_layer_weight is None:
-        #             output_layer_weight = weight
-        #         assert output_layer_weight is weight
-        #         func(total_input, grad_output, weight.grad)
-        #         output_layer_grads[j] = None  # release memory
-        #     weight_grad_tasks = weight_grad_tasks[1:]
-
-        # for i in range(len(weight_grad_tasks)):
-        #     tasks = weight_grad_tasks[i]
-        #     param = None
-        #     for j in range(len(tasks)):
-        #         total_input, grad_output, weight, func = tasks[j]
-        #         if param is None:
-        #             param = weight
-        #         assert param is weight
-        #         func(total_input, grad_output, weight.grad)
-        #         tasks[j] = None  # release memory
-        #     weight_grad_tasks[i] = None  # release memory
+        pass
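
For context on the block being removed: zero-bubble schedules split the backward pass, so input gradients run on the critical path while weight-gradient work is queued and replayed later to fill pipeline bubbles. A minimal sketch of such a store, following the (total_input, grad_output, weight, func) tuple layout visible in the comments above; the class below is illustrative, not ColossalAI's implementation.

import queue

class MiniWeightGradStore:
    # Weight-gradient work deferred during backward; applied on pop().
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # func computes the weight gradient, e.g. grad_output^T @ total_input
        # for a Linear layer.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Seal everything queued for one micro-batch into the chunk queue.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() == 0:
            raise Exception("Pop empty queue.")
        for total_input, grad_output, weight, func in cls.weight_grad_queue[chunk].get():
            grad_weight = func(total_input, grad_output)
            if weight.grad is None:
                weight.grad = grad_weight   # first backward: assign
            else:
                weight.grad += grad_weight  # later micro-batches: accumulate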

@@ -60,10 +60,7 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm

-        if self.pipeline_stage_manager:
-            use_zbv = self.pipeline_stage_manager.use_zbv
-        else:
-            use_zbv = False
+        use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
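
The collapsed one-liner is equivalent to the removed if/else as long as a PipelineStageManager instance is always truthy, which holds for any object that defines neither __bool__ nor __len__. A quick check with a stand-in manager:

class FakeStageManager:
    # Stand-in for PipelineStageManager; truthy like any plain object.
    def __init__(self, use_zbv: bool):
        self.use_zbv = use_zbv

def use_zbv_old(mgr):
    if mgr:
        return mgr.use_zbv
    return False

def use_zbv_new(mgr):
    return mgr is not None and mgr.use_zbv

for mgr in (None, FakeStageManager(True), FakeStageManager(False)):
    assert use_zbv_old(mgr) == use_zbv_new(mgr)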
@@ -96,7 +93,6 @@ class LlamaPolicy(Policy):
             target_key=attn_cls,
         )

-        # if self.pipeline_stage_manager is not None:
         if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
@@ -410,20 +406,6 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                     self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                 }
             ]
-        # if self.pipeline_stage_manager.use_zbv:
-        #     return [
-        #         {
-        #             0: llama_model.embed_tokens.weight,
-        #             0: self.model.lm_head.weight,
-        #         }
-        #     ]
-        # else:
-        #     return [
-        #         {
-        #             0: llama_model.embed_tokens.weight,
-        #             self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
-        #         }
-        #     ]
         return []
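
One plausible reason the zbv branch above stayed commented out: its first mapping repeats the key 0, and a Python dict literal keeps only the last value for a duplicated key, so embed_tokens.weight would be silently dropped from the shared-parameter mapping. A tiny demonstration with stand-in objects:

embed_tokens_weight = object()  # stand-in for llama_model.embed_tokens.weight
lm_head_weight = object()       # stand-in for self.model.lm_head.weight

# Mirrors {0: embed_tokens.weight, 0: lm_head.weight} from the comment:
mapping = {0: embed_tokens_weight, 0: lm_head_weight}
assert len(mapping) == 1
assert mapping[0] is lm_head_weight  # the embed_tokens entry was overwritten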