[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -12,17 +12,9 @@ from torch import autograd, nn, optim
from torch._C._distributed_rpc import PyRRef
from torch.futures import Future
from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.legacy.pipeline.middleware import Partition, Topo
from colossalai.legacy.pipeline.pipeline_process_group import ppg
from colossalai.legacy.pipeline.rpc.utils import (
get_batch_lengths,
pyobj_map,
pytree_filter,
pytree_map,
split_batch,
tensor_shape_list,
type_detail,
)
from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch
class Phase(Enum):
@@ -33,7 +25,7 @@ class Phase(Enum):
class UniqueKey:
__slots__ = ('microbatch_id', 'phase')
__slots__ = ("microbatch_id", "phase")
microbatch_id: int
phase: Phase
@@ -48,12 +40,22 @@ class UniqueKey:
return tuple.__hash__((self.microbatch_id, self.phase))
def __repr__(self) -> str:
return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})'
return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})"
class WorkItem:
__slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id',
'num_microbatches', 'forward_only')
__slots__ = (
"stage_id",
"phase",
"args",
"kwargs",
"output",
"refcount",
"microbatch_id",
"batch_id",
"num_microbatches",
"forward_only",
)
stage_id: int
phase: Phase
@@ -66,50 +68,45 @@ class WorkItem:
num_microbatches: int
forward_only: bool
def __init__(self,
stage_id,
phase,
args,
kwargs,
output,
microbatch_id,
batch_id,
num_microbatches,
forward_only,
refcount=0) -> None:
def __init__(
self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0
) -> None:
for attr_name in self.__slots__:
setattr(self, attr_name, locals()[attr_name])
class BackwardCache:
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
__slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs")
checkpoint: bool
stage_input_args: Tuple[Any]
stage_input_kwargs: Dict[Any, Any]
stage_outputs: Tuple[Any]
def __init__(self,
stage_input_args: Tuple[Any],
stage_input_kwargs: Dict[Any, Any] = None,
stage_outputs: Tuple[Any] = None,
checkpoint: bool = False) -> None:
def __init__(
self,
stage_input_args: Tuple[Any],
stage_input_kwargs: Dict[Any, Any] = None,
stage_outputs: Tuple[Any] = None,
checkpoint: bool = False,
) -> None:
for arg_name in self.__slots__:
setattr(self, arg_name, locals()[arg_name])
class WorkerBase(ABC):
def __init__(self,
partition_fn: Callable,
partition_args: tuple,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
partition_fn: Callable,
partition_args: tuple,
pp_rank: int,
actual_stage_num: int,
num_microbatches: int,
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
super().__init__()
self.pp_rank = pp_rank
@@ -150,11 +147,11 @@ class WorkerBase(ABC):
self._initialize_context_container()
# main loop
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True)
self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True)
self.main_loop_thread.start()
def _get_future_by_device(self):
return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device])
return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device])
def _initialize_outstanding_range(self):
outstanding_range = None
@@ -199,12 +196,13 @@ class WorkerBase(ABC):
# lifecycle management for DAG scheduler
if output_work_item.phase == Phase.FORWARD:
lifecycle = len(self.get_consumer_stage_ids())
if self.is_model_output(): # an extra reference for scheduler collecting results
if self.is_model_output(): # an extra reference for scheduler collecting results
lifecycle += 1
elif output_work_item.phase == Phase.BACKWARD:
lifecycle = len(self.get_producer_stage_ids())
if self.is_model_input() and self._is_last_step(
output_work_item): # an extra reference for ensure_backward
output_work_item
): # an extra reference for ensure_backward
lifecycle += 1
else:
lifecycle = 0
@@ -234,9 +232,9 @@ class WorkerBase(ABC):
# offset supports get partial output to reduce comm costs.
def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any:
output = self._get_output_all(key, ref_use, rank)
if offsets is None: # get all for non iterable output
if offsets is None: # get all for non iterable output
return output
else: # get part for iterable output
else: # get part for iterable output
output = [output[i] for i in offsets]
return output
@@ -252,12 +250,12 @@ class WorkerBase(ABC):
def get_partition(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
return self.module_partition
def get_partition_state_dict(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch, merge=False):
@@ -293,8 +291,17 @@ class WorkerBase(ABC):
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
work_item = WorkItem(
self.pp_rank,
Phase.FORWARD,
args,
kwargs,
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@@ -314,15 +321,25 @@ class WorkerBase(ABC):
for off in self_input_offsets:
self_arg_lst.append(arg_lst[off])
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
self.num_microbatches, forward_only)
work_item = WorkItem(
self.pp_rank,
Phase.FORWARD,
self_arg_lst,
{},
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
with self.work_list_condition_lock:
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
# put input tensor which other nodes need into output_list as Phase.INPUT
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
self.num_microbatches, forward_only)
work_item_remote = WorkItem(
self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only
)
with self.output_list_condition_lock:
self.output_list[recv_input_key] = work_item_remote
@@ -343,8 +360,17 @@ class WorkerBase(ABC):
output = self._get_future_by_device()
grad_wrt_loss = None
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False)
work_item = WorkItem(
self.pp_rank,
Phase.BACKWARD,
grad_wrt_loss,
{},
output,
microbatch_id,
None,
self.num_microbatches,
False,
)
self.work_list[key] = work_item
self.work_list_condition_lock.notify_all()
@@ -367,7 +393,7 @@ class WorkerBase(ABC):
producer_stage_ids = self.get_producer_stage_ids()
producer_num = len(producer_stage_ids)
if self.need_model_input():
producer_num += 1 # for input partition
producer_num += 1 # for input partition
subscribe_forward_futures: List[Future] = [None] * producer_num
# TODO(jiangziyue) get single value instead of the whole output
@@ -376,9 +402,9 @@ class WorkerBase(ABC):
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
offsets = self._get_input_offsets_by_index(target_index=0)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key,
rank=self.pp_rank,
offsets=offsets)
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets
)
for i in range(0, producer_num - 1):
producer_stage_id = producer_stage_ids[i]
@@ -386,11 +412,12 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i + 1
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
producer_output_key, rank=self.pp_rank, offsets=offsets
)
else:
for i in range(producer_num):
@@ -399,14 +426,24 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
target_index = i
offsets = self._get_input_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
producer_output_key, rank=self.pp_rank, offsets=offsets)
producer_output_key, rank=self.pp_rank, offsets=offsets
)
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
work_item_from_producer = WorkItem(
stage_id,
Phase.FORWARD,
subscribe_forward_futures,
{},
output,
microbatch_id,
None,
self.num_microbatches,
forward_only,
)
return work_item_from_producer
@@ -441,15 +478,25 @@ class WorkerBase(ABC):
consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id]
target_index = i
offsets = self._get_output_offsets_by_index(target_index=target_index)
if offsets is not None and len(offsets) == 0: # no need to do rpc
if offsets is not None and len(offsets) == 0: # no need to do rpc
subscribe_backward_futures[target_index] = []
else:
subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key(
consumer_output_key, rank=self.pp_rank, offsets=offsets)
consumer_output_key, rank=self.pp_rank, offsets=offsets
)
# flatten args
work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output,
microbatch_id, None, self.num_microbatches, False)
work_item_from_consumer = WorkItem(
stage_id,
Phase.BACKWARD,
subscribe_backward_futures,
{},
output,
microbatch_id,
None,
self.num_microbatches,
False,
)
return work_item_from_consumer
@@ -524,8 +571,8 @@ class WorkerBase(ABC):
def get_topo(self):
with self.partition_condition_lock:
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
if hasattr(self.module_partition, '_topo'):
self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition"))
if hasattr(self.module_partition, "_topo"):
return self.module_partition._topo
else:
return None
@@ -564,12 +611,12 @@ class WorkerBase(ABC):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if target_index == src_index:
if output_len == 1:
res = None # offset = None to get all outputs
res = None # offset = None to get all outputs
return res
else:
res.append(src_offset)
@@ -584,7 +631,6 @@ class WorkerBase(ABC):
consumer_stage_ids = self.get_consumer_stage_ids()
for val_list in output_vals:
# An output may be passed to many down stages.
target = None
for val_pos in val_list.get():
dst_partition_id = val_pos.partition_id
dst_offset = val_pos.offset
@@ -597,7 +643,7 @@ class WorkerBase(ABC):
break
if target_index == dst_index:
if input_len == 1:
res = None # offset = None to get all outputs
res = None # offset = None to get all outputs
return res
else:
res.append(dst_offset)
@@ -623,7 +669,7 @@ class WorkerBase(ABC):
flatten_args = []
if self.is_first_stage():
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
else: # get by offset
else: # get by offset
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
self_partition: Partition = topo.get_partition_by_id(self_partition_id)
@@ -652,7 +698,7 @@ class WorkerBase(ABC):
if stage_id == src_stage_id:
src_index += i
break
else: # data from input partition
else: # data from input partition
src_index = 0
# when output_len = 1, not iterable
if output_len == 1:
@@ -679,7 +725,7 @@ class WorkerBase(ABC):
else:
for i, arg in enumerate(args_or_kwargs):
args_or_kwargs[i] = arg.wait()
if args_or_kwargs is not None: # get by offset
if args_or_kwargs is not None: # get by offset
flatten_args = []
topo: Topo = self.get_topo()
self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo)
@@ -719,7 +765,7 @@ class WorkerBase(ABC):
@abstractmethod
def _get_work_item_key(self) -> UniqueKey:
"""
this method control the order of the microbatch to consume
this method control the order of the microbatch to consume
"""
def is_first_stage(self):
@@ -761,7 +807,7 @@ class WorkerBase(ABC):
kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
data_process_func = getattr(self, "data_process_func", self._default_data_process_func)
consume_result = None
is_first_stage = self.is_first_stage()
@@ -787,10 +833,12 @@ class WorkerBase(ABC):
else:
args_kwargs = self._get_real_args_kwargs_fwd(args)
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device,
process_types=torch.device) # change devices from last stage to current device
args_kwargs = pyobj_map(
args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
args_kwargs = pyobj_map(
args_kwargs, fn=lambda x: self.device, process_types=torch.device
) # change devices from last stage to current device
args, kwargs = data_process_func(args_kwargs)
@@ -851,16 +899,16 @@ class WorkerBase(ABC):
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(
stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint
)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
if is_last_stage: # if it is the last stage, trigger backward automatic
self._begin_backward(microbatch_id)
elif phase == Phase.BACKWARD:
@@ -872,7 +920,9 @@ class WorkerBase(ABC):
self.backward_times += 1
self.outstanding -= 1
assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache"
assert (
microbatch_id in self.microbatch_id_to_backward_cache
), f"microbatch_id {microbatch_id} not in backward cache"
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
stage_outputs = backward_cache.stage_outputs
@@ -906,8 +956,9 @@ class WorkerBase(ABC):
filtered_grads.append(grad)
stage_outputs = filtered_outputs
grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
grad_tensors = pyobj_map(
filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# collect grad of input tensor
@@ -920,8 +971,8 @@ class WorkerBase(ABC):
else:
consume_result.append(None)
consume_result = pyobj_map(
consume_result, fn=lambda x: x.to('cpu'),
process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU
consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor
) # torch rpc doesn't support args or rets in GPU
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -929,7 +980,7 @@ class WorkerBase(ABC):
return consume_result
def _get_store_len(self):
return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}'
return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}"
def _get_parameter_grad_sum(self):
grad_sum = 0
@@ -1014,19 +1065,20 @@ class WorkerBase(ABC):
class PipelineEngineBase(ABC, nn.Module):
def __init__(self,
worker_type,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
def __init__(
self,
worker_type,
partition_fn: Callable,
stage_num,
num_microbatches,
device: str,
use_1F1B=False,
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False,
data_process_func: Callable = None,
) -> None:
super().__init__()
self.worker_type = worker_type
self.partition_fn: Callable = partition_fn
@@ -1056,12 +1108,12 @@ class PipelineEngineBase(ABC, nn.Module):
data_process_func = self.data_process_func
if data_process_func is not None:
assert callable(data_process_func), "data_process_func must be a function"
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
assert "<locals>" not in data_process_func.__repr__(), "data_process_func must be a global function"
assert "<lambda>" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
sig = inspect.signature(data_process_func)
assert len(
sig.parameters
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
assert (
len(sig.parameters) == 2
), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
@@ -1104,19 +1156,33 @@ class PipelineEngineBase(ABC, nn.Module):
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
partition_args = (partition_id, chunk, actual_stage_num)
rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank]
if device[:4] == 'cuda':
device = f'cuda:{rpc_worker_id}'
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id,
worker_type,
args=(partition_fn, partition_args, pp_rank,
actual_stage_num, num_microbatches, device,
criterion, metric, checkpoint, data_process_func))
if device[:4] == "cuda":
device = f"cuda:{rpc_worker_id}"
self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(
rpc_worker_id,
worker_type,
args=(
partition_fn,
partition_args,
pp_rank,
actual_stage_num,
num_microbatches,
device,
criterion,
metric,
checkpoint,
data_process_func,
),
)
# let each worker know global worker rref (include itself)
sync_futs = []
for pp_rank in self.pp_rank_to_worker_rref:
fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
self.pp_rank_to_worker_rref)
fut = (
self.pp_rank_to_worker_rref[pp_rank]
.rpc_async(timeout=0)
.sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
)
sync_futs.append(fut)
for fut in sync_futs:
@@ -1157,8 +1223,9 @@ class PipelineEngineBase(ABC, nn.Module):
def get_output_pp_ranks(self) -> List[int]:
return [self._get_actual_stage_num() - 1]
def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int],
output_pp_ranks: List[int], ret_future):
def _consume_constraint(
self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future
):
actual_stage_num = self._get_actual_stage_num()
use_1F1B = self.use_1F1B
if microbatch_id >= actual_stage_num:
@@ -1206,7 +1273,8 @@ class PipelineEngineBase(ABC, nn.Module):
worker_rref = self.pp_rank_to_worker_rref[pp_rank]
key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD)
fut = worker_rref.rpc_async().get_output_by_key(
key, offsets=[]) # only ensure the res exists, no need for real data.
key, offsets=[]
) # only ensure the res exists, no need for real data.
backward_result.append(fut)
for fut in backward_result:
@@ -1244,11 +1312,14 @@ class PipelineEngineBase(ABC, nn.Module):
if labels is not None and not forward_only:
assert hasattr(
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
self, "optimizer_class"
), "call `initialize_optimizer` to initialize optimizer before forward_backward"
num_microbatches = self.num_microbatches
assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal"
assert (
batch_length >= num_microbatches
), "num_microbatches is greater than the size of a batch, which is illegal"
microbatch_size = math.ceil(batch_length / num_microbatches)
device = self.device
@@ -1285,10 +1356,10 @@ class PipelineEngineBase(ABC, nn.Module):
# collect forward result
forward_result = self._collect_forward_result(output_pp_ranks, ret_future)
if not forward_only and hasattr(self, 'optimizer_class'):
if not forward_only and hasattr(self, "optimizer_class"):
self.step()
self._reset_worker() # reset worker attributes for next batch
self._reset_worker() # reset worker attributes for next batch
return forward_result
def initialize_optimizer(self, optimizer_class: type, **kwargs):