[fix] rm useless assign and comments;

duanjunwen 2024-08-27 07:31:58 +00:00
parent 1b4bb2beeb
commit 283c9ff5d2
2 changed files with 12 additions and 10 deletions


@@ -440,9 +440,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
         else:
             # common bwd step
-            # print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
             # BUG: output_obj_grad is None
-            # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
             torch.autograd.backward(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
             )
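A note on the `else` branch above: `torch.autograd.backward` can only synthesize `grad_tensors` implicitly for scalar outputs, so the `# BUG: output_obj_grad is None` situation is only safe when `output_obj` is a scalar loss; for a non-scalar activation, a `None` gradient raises a runtime error. A minimal, self-contained sketch of both call shapes (the names `x`, `y`, `z` are illustrative, not from the scheduler):

import torch

x = torch.randn(4, requires_grad=True)

# Scalar output: grad_tensors may be omitted (the last-stage loss case).
y = (x * 2).sum()
torch.autograd.backward(y, inputs=[x], retain_graph=True)

# Non-scalar output: grad_tensors must be supplied (the intermediate-stage case);
# passing None here raises "grad can be implicitly created only for scalar outputs".
z = x * 3
torch.autograd.backward(tensors=z, grad_tensors=torch.ones_like(z), inputs=[x])

print(x.grad)  # gradients from both calls accumulate into x.grad: a tensor of 5.0s

The `inputs=` argument restricts which tensors receive accumulated `.grad`, which is what lets a zero-bubble schedule split backward into separate input-grad (B) and weight-grad (W) steps.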
@@ -516,7 +514,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 input_obj = input_obj
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
-
         else:
             # is last stage; recv from local
             if self.stage_manager.is_last_stage(ignore_chunk=True):
@@ -535,8 +532,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 outputs=outputs,
             )
-        # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
-
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
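The two `append` calls are the forward-side half of the schedule's bookkeeping: each forward (F) node saves its activation pair so the matching backward (B) node can pop them later in FIFO order. A toy sketch of that pairing, assuming one queue per model chunk as above (`fwd_step` and `bwd_step` are illustrative names, not the scheduler's API):

import torch

input_tensors = {0: [], 1: []}    # per-chunk FIFO queues, as in the scheduler
output_tensors = {0: [], 1: []}

def fwd_step(chunk_id, x, layer):
    out = layer(x)
    input_tensors[chunk_id].append(x)     # saved for the matching B node
    output_tensors[chunk_id].append(out)
    return out

def bwd_step(chunk_id, out_grad):
    x = input_tensors[chunk_id].pop(0)    # FIFO: oldest microbatch first
    out = output_tensors[chunk_id].pop(0)
    torch.autograd.backward(tensors=out, grad_tensors=out_grad, inputs=x, retain_graph=True)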
@@ -681,7 +676,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         it = 0
         # while we still have schedules_node in self.schedules
-        # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
             print(
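For orientation, the loop above walks the precomputed `self.schedules` list and dispatches each node in order. A hedged sketch of that shape (the field names and the F/B/W node types are assumptions about the zero-bubble schedule, not taken verbatim from this file):

from dataclasses import dataclass

@dataclass
class ScheduledNode:          # hypothetical layout
    type: str                 # e.g. "F" (forward), "B" (input grads), "W" (weight grads)
    chunk: int
    microbatch: int

def run(schedules):
    it = 0
    while it < len(schedules):
        node = schedules[it]
        if node.type == "F":
            pass   # run forward for (node.chunk, node.microbatch)
        elif node.type == "B":
            pass   # backward w.r.t. inputs only
        elif node.type == "W":
            pass   # backward w.r.t. weights, deferred to fill pipeline bubbles
        it += 1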


@@ -1,6 +1,7 @@
 from copy import deepcopy
 from typing import Tuple
 
+import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -139,7 +140,7 @@ def test_run_fwd_bwd_base(
     ]
     scheduler = ZeroBubbleVPipeScheduler(
-        schedule=zbv_schedule[rank],
+        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
         stage_manager=stage_manager,
         num_model_chunks=pp_size,
         num_microbatch=1,
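On the new `# hint` comment: the test currently resolves the question by slicing, so each rank only ever sees its own plan. The two options it weighs, sketched for illustration only:

# Option A (what the test does): pass only this rank's local plan.
scheduler = ZeroBubbleVPipeScheduler(
    schedule=zbv_schedule[rank],
    stage_manager=stage_manager,
    num_model_chunks=pp_size,
    num_microbatch=1,
)

# Option B (the alternative the hint raises): pass the whole table and let the
# scheduler index it internally, e.g. schedule[self.stage_manager.stage].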
@@ -226,7 +227,6 @@ def test_run_fwd_bwd_base(
         # layer 6
         assert_close(local_chunk[1].weight, model_base.layers[6].weight)
         assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
-
     if rank == 2:
         # layer 2
         assert_close(local_chunk[0].weight, model_base.layers[2].weight)
@@ -234,7 +234,6 @@ def test_run_fwd_bwd_base(
         # layer 5
         assert_close(local_chunk[1].weight, model_base.layers[5].weight)
         assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
-
     if rank == 3:
         # layer 3
         assert_close(local_chunk[0].weight, model_base.layers[3].weight)
@@ -244,7 +243,16 @@ def test_run_fwd_bwd_base(
         assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
 
-# @pytest.mark.dist
+# Test iter input & multiple microbatch
+def test_run_fwd_bwd_iter_input(
+    rank: int,
+    world_size: int,
+    port: int,
+):
+    pass
+
+
+@pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
 # @pytest.mark.parametrize("batch_size", [4])
 # @pytest.mark.parametrize("num_model_chunk", [2])
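The new `test_run_fwd_bwd_iter_input` stub follows the usual ColossalAI convention for distributed test workers with a `(rank, world_size, port)` signature; such a worker is normally launched from a small `@pytest.mark.dist` wrapper. A hedged sketch of that harness (the wrapper name and `nprocs=4` are assumptions):

import pytest
from colossalai.testing import rerun_if_address_is_in_use, spawn

@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zbv_iter_input():  # hypothetical wrapper name
    # spawn() starts one process per pipeline stage and passes each one
    # (rank, world_size, port), matching the stub's signature above.
    spawn(test_run_fwd_bwd_iter_input, nprocs=4)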