[feat] Linear1D_COL/ROW support zbv WeightGradStore;

duanjunwen
2024-10-14 07:02:43 +00:00
parent 0ca16d5cbe
commit cfade4c36d
7 changed files with 820 additions and 28 deletions
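
Background for the diff below: the zero-bubble ("zbv") V-schedule relies on the fact that, for a linear layer y = x W^T, the gradient with respect to the input and the gradient with respect to the weight are independent, so the weight gradient (the W pass) can be postponed to fill pipeline bubbles, while the input gradient (the B pass) must be produced and sent to the previous stage immediately. A minimal PyTorch sketch of that split, illustrative only and not code from this commit:

# Illustrative sketch (not from this commit): for y = x @ W.T, dX and dW are
# independent, so dW (the "W" pass) can wait while dX (the "B" pass) cannot.
import torch

x = torch.randn(2, 8)   # layer input saved from the forward pass
W = torch.randn(4, 8)   # weight of a Linear(8, 4)
dy = torch.randn(2, 4)  # gradient w.r.t. the layer output

dx = dy @ W       # B pass: needed immediately by the previous pipeline stage
dW = dy.t() @ x   # W pass: only needed before the optimizer step, so it can be deferred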


@@ -11,6 +11,7 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.pipeline.weight_grad_store import WeightGradStore
 
 from ._utils import (
     clone,
@@ -650,10 +651,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             self.output_tensors[model_chunk_id].append(output_obj)
-            self.output_tensors_dw[model_chunk_id].append(output_obj)
+            # self.output_tensors_dw[model_chunk_id].append(output_obj)
         else:
             self.output_tensors[model_chunk_id].append(output_obj)
-            self.output_tensors_dw[model_chunk_id].append(output_obj)
+            # self.output_tensors_dw[model_chunk_id].append(output_obj)
 
         # add output to send_fwd_buffer
         if model_chunk_id == 0: # chunk 0
@@ -705,13 +706,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)
 
-        # save output_tensor_grad for dw
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # we save loss here
-            self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
-        else:
-            # we save output_tensor_grad here
-            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
+        # # save output_tensor_grad for dw
+        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+        #     # we save loss here
+        #     self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
+        # else:
+        #     # we save output_tensor_grad here
+        #     self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
         # Step2: bwd step
         input_object_grad = self.backward_b_step(
@@ -738,6 +739,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             # send to next
             else:
                 self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+        WeightGradStore.flush(chunk=model_chunk_id)
 
     def schedule_w(
         self,
@@ -757,16 +759,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
 
         # get y & dy from buffer
-        output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
-        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
+        # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+        # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
 
-        self.backward_w_step(
-            model_chunk=model_chunk,
-            model_chunk_id=model_chunk_id,
-            optimizer=optimizer,
-            output_obj=output_obj,
-            output_obj_grad=output_obj_grad,
-        )
+        WeightGradStore.pop(chunk=model_chunk_id)
+
+        # self.backward_w_step(
+        #     model_chunk=model_chunk,
+        #     model_chunk_id=model_chunk_id,
+        #     optimizer=optimizer,
+        #     output_obj=output_obj,
+        #     output_obj_grad=output_obj_grad,
+        # )
 
     def run_forward_only(
         self,
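
Net effect of the scheduler changes above: the scheduler no longer stashes output_obj / output_tensor_grad for a separate dW pass; instead, whatever the layers queued into WeightGradStore during the backward-for-input (B) step is sealed per model chunk by flush() at the end of schedule_b, and schedule_w drains it with pop() instead of calling backward_w_step. A minimal sketch of that pairing, using only the WeightGradStore API introduced by this commit; the tensors and the dw_func closure are made-up stand-ins for what a layer would enqueue:

# Minimal sketch of the new B/W pairing; dw_func and the tensors are stand-ins.
import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore

weight = torch.nn.Parameter(torch.randn(4, 8))
x, dy = torch.randn(2, 8), torch.randn(2, 4)  # saved input / incoming output grad

def dw_func(total_input, grad_output, grad=None):
    # Mirrors the two call forms used by WeightGradStore.pop(): accumulate into an
    # existing weight.grad in place, or return a fresh gradient tensor.
    if grad is None:
        return grad_output.t() @ total_input
    grad += grad_output.t() @ total_input

# B step: a layer defers its weight gradient instead of computing it inline ...
WeightGradStore.put(x, dy, weight, dw_func)
# ... and schedule_b seals everything cached for this model chunk.
WeightGradStore.flush(chunk=0)

# W step: schedule_w drains the queue and materializes weight.grad.
WeightGradStore.pop(chunk=0)
assert weight.grad is not None

Since pop() raises on an empty queue, the schedule has to pair each W node with an earlier flush() for the same chunk.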


@@ -0,0 +1,106 @@
import queue

# from megatron import get_args
# from megatron.core import parallel_state
# from megatron.core.distributed.finalize_model_grads import _allreduce_embedding_grads
# from megatron.core.utils import get_model_config, get_attr_wrapped_model


class WeightGradStore:
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # func(total_input, grad_output, weight.main_grad)
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if weight.grad is not None:
                    func(total_input, grad_output, weight.grad)
                # for first bwd; weight.grad is None, assign grad_weight to weight.grad
                else:
                    grad_weight = func(total_input, grad_output)
                    weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")
    # @classmethod
    # def clear(cls, model, chunk=0):
    #     weight_grad_tasks = []
    #     while cls.weight_grad_queue[chunk].qsize() > 0:
    #         stored_grads = cls.weight_grad_queue[chunk].get()
    #         if len(weight_grad_tasks) == 0:
    #             for _ in stored_grads:
    #                 weight_grad_tasks.append([])
    #         else:
    #             assert len(weight_grad_tasks) == len(stored_grads)
    #         for i, task in enumerate(stored_grads):
    #             weight_grad_tasks[i].append(task)
    #     weight_params = []
    #     handles = []
    #     if get_args().overlap_grad_reduce:
    #         handles += model.async_reduce_grad()
    #     output_layer_weight = None
    #     if parallel_state.is_pipeline_last_stage():
    #         assert len(weight_grad_tasks) > 0
    #         output_layer_grads = weight_grad_tasks[0]
    #         for j in range(len(output_layer_grads)):
    #             total_input, grad_output, weight, func = output_layer_grads[j]
    #             if output_layer_weight is None:
    #                 output_layer_weight = weight
    #             assert output_layer_weight is weight
    #             func(total_input, grad_output, weight.main_grad)
    #             output_layer_grads[j] = None # release memory
    #         weight_grad_tasks = weight_grad_tasks[1:]
    #         if get_args().overlap_grad_reduce:
    #             handles += model.async_reduce_grad(output_layer_weight)
    #     if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
    #         model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True)
    #         if model_module.share_embeddings_and_output_weights:
    #             # if share_embeddings_and_output_weights, wait all-reduce for embeddings
    #             for handle in handles:
    #                 if handle is not None:
    #                     handle.wait()
    #             handles = []
    #             config = get_model_config(model)
    #             # Do async all-reduce for embedding grads firstly, so that the rank 0 won't
    #             # be blocked
    #             embedding_handles = _allreduce_embedding_grads([model], config, async_op=True)
    #             handles += embedding_handles
    #     for i in range(len(weight_grad_tasks)):
    #         tasks = weight_grad_tasks[i]
    #         param = None
    #         for j in range(len(tasks)):
    #             total_input, grad_output, weight, func = tasks[j]
    #             if param is None:
    #                 param = weight
    #             assert param is weight
    #             assert not (weight is output_layer_weight)
    #             func(total_input, grad_output, weight.main_grad)
    #             tasks[j] = None # release memory
    #         weight_params.append(param)
    #         if get_args().overlap_grad_reduce:
    #             # All-reduce param grad here
    #             handles += model.async_reduce_grad(param)
    #         weight_grad_tasks[i] = None # release memory
    #     # timers('wait_all_reduce', log_level=1).start(barrier=False)
    #     for handle in embedding_handles:
    #         if handle is not None:
    #             handle.wait()
    #     # timers('wait_all_reduce').stop()
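
The Linear1D_COL/ROW changes named in the commit title are among the 7 changed files but are not part of this excerpt. As an illustration of how a layer feeds the store, below is a hypothetical torch.autograd.Function (not code from this commit) whose backward returns the input gradient immediately and defers the weight gradient through WeightGradStore.put, using a helper that matches the two call forms pop() uses:

# Hypothetical layer-side sketch (not from this commit).
import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore

def _dw(total_input, grad_output, grad=None):
    # Later invoked by WeightGradStore.pop(): accumulate into an existing
    # weight.grad in place, or return a fresh gradient tensor.
    if grad is None:
        return grad_output.t() @ total_input
    grad += grad_output.t() @ total_input

class DeferredLinearFunction(torch.autograd.Function):
    """Stand-in for a Linear1D_COL/ROW-style matmul with deferred dW."""

    @staticmethod
    def forward(ctx, input_, weight):
        ctx.save_for_backward(input_, weight)
        return input_ @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        # B-step work: the input gradient must flow to the previous stage right away.
        grad_input = grad_output @ weight
        # W-step work: enqueue the weight gradient; the ZBV scheduler flushes it
        # after the B step and pops it during the W step.
        WeightGradStore.put(input_, grad_output, weight, _dw)
        return grad_input, None

A real Linear1D_COL/ROW implementation would additionally handle the tensor-parallel communication of grad_input (all-reduce, all-gather, or reduce-scatter), which is omitted here.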