mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[rpc] split with dag (#2028)
* add DAG to split_module * add comment * add test case for DAG * remove print * add DAG middleware in scheduler * add test case for scheduler * remove break * recover old lifecycle Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -117,7 +117,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
|
||||
return gm
|
||||
|
||||
|
||||
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
|
||||
def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
|
||||
# TODO(lyl): use partition IR to assign partition ID to each node.
|
||||
# Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
|
||||
# In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
|
||||
@@ -129,7 +129,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
|
||||
part_idx += 1
|
||||
return part_idx
|
||||
|
||||
split_mod = split_module(annotated_gm, None, split_callback)
|
||||
split_mod = split_module(annotated_gm, None, split_callback, merge_output)
|
||||
split_submodules = []
|
||||
for name, submodule in split_mod.named_modules():
|
||||
if isinstance(submodule, torch.fx.GraphModule):
|
||||
|
@@ -199,24 +199,17 @@ def find_user_in_partition(node, partitions, output_partitions=None, direct=Fals
|
||||
for partition in partitions:
|
||||
if node == partition:
|
||||
user_partition_names.append(partition.name)
|
||||
|
||||
# find user with getitem call
|
||||
else:
|
||||
for partition in partitions:
|
||||
if node in partition.args:
|
||||
user_partition_names.append(partition.name)
|
||||
|
||||
is_output = False
|
||||
def find_output(def_node, output_node):
|
||||
nonlocal is_output
|
||||
if def_node == output_node:
|
||||
is_output = True
|
||||
|
||||
|
||||
if output_partitions is not None:
|
||||
output_node = output_partitions[0]
|
||||
torch.fx.graph.map_arg(output_node.args[0], lambda n: find_output(node, n))
|
||||
|
||||
if is_output:
|
||||
user_partition_names.append('MODEL_OUTPUT')
|
||||
if node.op == output_node.op:
|
||||
user_partition_names.append('MODEL_OUTPUT')
|
||||
|
||||
if len(user_partition_names) > 0:
|
||||
return user_partition_names
|
||||
|
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Tuple
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (get_batch_lengths, get_real_args_kwargs, pytree_filter, pytree_map,
|
||||
from colossalai.pipeline.rpc.utils import (get_batch_lengths, pytree_filter, pytree_map,
|
||||
split_batch, tensor_shape_list, type_detail)
|
||||
from torch import autograd, nn, optim
|
||||
from torch._C._distributed_rpc import PyRRef
|
||||
@@ -20,7 +20,7 @@ class Phase(Enum):
|
||||
FORWARD = 0
|
||||
BACKWARD = 1
|
||||
UPDATE = 2
|
||||
|
||||
INPUT = 3
|
||||
|
||||
class UniqueKey:
|
||||
__slots__ = ('microbatch_id', 'phase')
|
||||
@@ -128,6 +128,7 @@ class WorkerBase(ABC):
|
||||
# topology info
|
||||
self.producer_stage_ids: List[int] = None
|
||||
self.consumer_stage_ids: List[int] = None
|
||||
self.input_consumer_stage_ids: List[int] = None
|
||||
|
||||
# module partitions
|
||||
self.partition_fn = partition_fn
|
||||
@@ -135,6 +136,11 @@ class WorkerBase(ABC):
|
||||
self.criterion = criterion
|
||||
self.metric = metric
|
||||
|
||||
# middleware info
|
||||
self._is_input = False
|
||||
self._is_output = False
|
||||
self._producer_consumer_initialized = False
|
||||
|
||||
# context to maintain loop
|
||||
self._initialize_context_container()
|
||||
|
||||
@@ -164,6 +170,7 @@ class WorkerBase(ABC):
|
||||
self.work_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.output_list_condition_lock = threading.Condition(threading.Lock())
|
||||
self.label_lock = threading.Condition(threading.Lock())
|
||||
self.producer_consumer_init_lock = threading.Condition(threading.Lock())
|
||||
|
||||
def _initialize_partition(self):
|
||||
partition_fn = self.partition_fn
|
||||
@@ -182,7 +189,7 @@ class WorkerBase(ABC):
|
||||
# construction of partition is executed after the registion of pp_rank_to_worker_rref
|
||||
self._initialize_partition()
|
||||
|
||||
def get_output_by_key(self, key: UniqueKey) -> Any:
|
||||
def get_output_by_key(self, key: UniqueKey, recv_rank=None) -> Any:
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list_condition_lock.wait_for(lambda: key in self.output_list)
|
||||
output_work_item = self.output_list[key]
|
||||
@@ -191,8 +198,9 @@ class WorkerBase(ABC):
|
||||
if isinstance(output, Future):
|
||||
output = output.wait()
|
||||
|
||||
output_work_item.refcount += 1
|
||||
# output_work_item.refcount += 1
|
||||
|
||||
# TODO(jiangziyue) redesign lifecycle management for DAG scheduler
|
||||
# all consumers have been satisfied, the work_item can be released
|
||||
with self.output_list_condition_lock:
|
||||
if output_work_item.refcount >= len(self.consumer_stage_ids):
|
||||
@@ -215,8 +223,10 @@ class WorkerBase(ABC):
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
return self.module_partition.state_dict()
|
||||
|
||||
def _make_args_kwargs(self, microbatch):
|
||||
def _make_args_kwargs(self, microbatch, merge=False):
|
||||
if isinstance(microbatch, dict):
|
||||
if merge:
|
||||
return list(microbatch.values()), {}
|
||||
return [], microbatch
|
||||
elif isinstance(microbatch, torch.Tensor):
|
||||
return [microbatch], {}
|
||||
@@ -228,24 +238,70 @@ class WorkerBase(ABC):
|
||||
kwargs.update(arg)
|
||||
else:
|
||||
args.append(arg)
|
||||
if merge:
|
||||
arg_lst = args
|
||||
for arg in kwargs.values():
|
||||
arg_lst.append(arg)
|
||||
return arg_lst, {}
|
||||
return args, kwargs
|
||||
else:
|
||||
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
|
||||
|
||||
# just for first pp_rank
|
||||
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
|
||||
# TODO(jiangziyue) Define a Class for DAG.
|
||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||
assert self.consumer_stage_ids is not None
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
output = self._get_future_by_device()
|
||||
|
||||
if not self.use_middleware():
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
else:
|
||||
# make args and kwargs
|
||||
arg_lst, _ = self._make_args_kwargs(microbatch, merge=True)
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
# first stage assign correct input into other stages
|
||||
DAG = self.get_DAG()
|
||||
DAG_node = DAG['input_partition']
|
||||
self_input_offsets = []
|
||||
recv_input_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
# notify rank which should receive extra input
|
||||
offset = 0
|
||||
for details in DAG_node.values():
|
||||
for partition_name in details['output'].keys():
|
||||
recv_rank = self.partition_name_to_pp_rank(partition_name)
|
||||
if recv_rank == self.pp_rank:
|
||||
self_input_offsets.append(offset)
|
||||
elif recv_rank not in self.input_consumer_stage_ids:
|
||||
self.input_consumer_stage_ids.append(recv_rank)
|
||||
offset += 1
|
||||
|
||||
# set input for self rank
|
||||
self_arg_lst = []
|
||||
for off in self_input_offsets:
|
||||
self_arg_lst.append(arg_lst[off])
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# put input tensor which other nodes need into output_list
|
||||
work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
self.output_list[recv_input_key] = work_item_remote
|
||||
self.output_list_condition_lock.notify_all()
|
||||
|
||||
# just for last pp_rank
|
||||
def set_labels(self, microbatch_id: int, microlabels: Any):
|
||||
@@ -268,33 +324,68 @@ class WorkerBase(ABC):
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
|
||||
def subscribe_producer(self, microbatch_id: int, forward_only: bool):
|
||||
"""
|
||||
You should call this function asynchronously
|
||||
"""
|
||||
assert self.producer_stage_ids is not None
|
||||
producer_num = len(self.producer_stage_ids)
|
||||
assert producer_num > 0, "only stage that has producers can subscribe producers"
|
||||
|
||||
stage_id = self.pp_rank
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
output = self._get_future_by_device()
|
||||
if not self.use_middleware():
|
||||
producer_num = len(self.producer_stage_ids)
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
for i in range(producer_num):
|
||||
producer_stage_id = self.producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
else:
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
if key in self.work_list:
|
||||
return
|
||||
|
||||
for i in range(producer_num):
|
||||
producer_stage_id = self.producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
producer_stage_ids = []
|
||||
with self.producer_consumer_init_lock:
|
||||
self.producer_consumer_init_lock.wait_for(lambda: self._producer_consumer_initialized)
|
||||
producer_stage_ids = self.producer_stage_ids
|
||||
producer_num = len(producer_stage_ids)
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
if self.need_model_input():
|
||||
producer_num += 1 # extra one(the last one) for input_tensor
|
||||
subscribe_forward_futures: List[Future] = [None] * producer_num
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
if self.need_model_input():
|
||||
producer_stage_id = 0
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.INPUT)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
|
||||
for i in range(0, producer_num-1):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i+1] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
|
||||
else:
|
||||
for i in range(producer_num):
|
||||
producer_stage_id = producer_stage_ids[i]
|
||||
producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
#producer_partition_name = self.pp_rank_to_partition_name[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, self.pp_rank)
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
||||
# add work_item to work_list
|
||||
with self.work_list_condition_lock:
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_producer
|
||||
self.work_list_condition_lock.notify_all()
|
||||
if key not in self.work_list:
|
||||
self.work_list[key] = work_item_from_producer
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def subscribe_consumer(self, microbatch_id: int):
|
||||
"""
|
||||
@@ -334,13 +425,132 @@ class WorkerBase(ABC):
|
||||
self.producer_stage_ids = []
|
||||
self.consumer_stage_ids = []
|
||||
|
||||
# Just for demo
|
||||
prev_rank = rank - 1
|
||||
next_rank = rank + 1
|
||||
if prev_rank >= 0:
|
||||
self.producer_stage_ids.append(prev_rank)
|
||||
if next_rank <= self.actual_stage_num - 1:
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
if not self.use_middleware():
|
||||
# Just for demo
|
||||
prev_rank = rank - 1
|
||||
next_rank = rank + 1
|
||||
if prev_rank >= 0:
|
||||
self.producer_stage_ids.append(prev_rank)
|
||||
if next_rank <= self.actual_stage_num - 1:
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
else:
|
||||
self.input_consumer_stage_ids = []
|
||||
DAG = self.get_DAG()
|
||||
DAG_node_name = self.pp_rank_to_partition_name(rank)
|
||||
DAG_node = DAG[DAG_node_name]
|
||||
for partition_name in DAG_node['input'].keys():
|
||||
if partition_name == 'MODEL_INPUT':
|
||||
self._is_input = True
|
||||
else:
|
||||
prev_rank = self.partition_name_to_pp_rank(partition_name)
|
||||
self.producer_stage_ids.append(prev_rank)
|
||||
|
||||
for partition_name in DAG_node['output'].keys():
|
||||
if partition_name == 'MODEL_OUTPUT':
|
||||
self._is_output = True
|
||||
else:
|
||||
next_rank = self.partition_name_to_pp_rank(partition_name)
|
||||
self.consumer_stage_ids.append(next_rank)
|
||||
|
||||
# TODO(jiangziyue) Consider whether this function should be protected by Lock in DAG env.
|
||||
with self.producer_consumer_init_lock:
|
||||
self._producer_consumer_initialized = True
|
||||
self.producer_consumer_init_lock.notify_all()
|
||||
|
||||
# TODO(jiangziyue) Define a Class for DAG.
|
||||
def pp_rank_to_partition_name(self, pp_rank: int):
|
||||
prefix = 'submod_'
|
||||
partition_name = prefix + str(pp_rank)
|
||||
return partition_name
|
||||
|
||||
# TODO(jiangziyue) Define a Class for DAG.
|
||||
def partition_name_to_pp_rank(self, partition_name: str) -> int:
|
||||
prefix = 'submod_'
|
||||
pp_rank = int(partition_name.split(prefix)[-1])
|
||||
return pp_rank
|
||||
|
||||
def get_DAG(self):
|
||||
with self.partition_condition_lock:
|
||||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
if hasattr(self.module_partition, '_DAG'):
|
||||
return self.module_partition._DAG
|
||||
else:
|
||||
return None
|
||||
|
||||
def use_middleware(self):
|
||||
DAG = self.get_DAG()
|
||||
return DAG is not None
|
||||
|
||||
# TODO(jiangziyue) get single value instead of the whole output
|
||||
def _get_real_args_kwargs(self, args_or_kwargs):
|
||||
if not self.use_middleware():
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
if isinstance(args_or_kwargs, dict):
|
||||
pass
|
||||
else:
|
||||
flatten_args = []
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
args_or_kwargs = flatten_args
|
||||
else:
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
if args_or_kwargs is not None:
|
||||
if isinstance(args_or_kwargs, dict):
|
||||
pass
|
||||
else:
|
||||
flatten_args = []
|
||||
if self.is_first_stage():
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
# TODO get by offset
|
||||
else:
|
||||
DAG = self.get_DAG()
|
||||
producer_outputs = {}
|
||||
cur_DAG_node_name = self.pp_rank_to_partition_name(self.pp_rank)
|
||||
#cur_DAG_node = DAG[self.pp_rank_to_partition_name(self.pp_rank)]
|
||||
for i, args_from_one_mod in enumerate(args_or_kwargs):
|
||||
producer_output_offsets = []
|
||||
if self.need_model_input():
|
||||
if i == 0:
|
||||
producer_DAG_node = DAG['input_partition']
|
||||
producer_partition_name = 'MODEL_INPUT'
|
||||
offset = 0
|
||||
for arg_info in producer_DAG_node.values():
|
||||
if cur_DAG_node_name in arg_info['output']:
|
||||
producer_output_offsets.append(offset)
|
||||
offset += 1
|
||||
else:
|
||||
producer_rank = self.producer_stage_ids[i-1]
|
||||
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
|
||||
producer_DAG_node = DAG[producer_partition_name]
|
||||
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
|
||||
|
||||
else:
|
||||
producer_rank = self.producer_stage_ids[i]
|
||||
producer_partition_name = self.pp_rank_to_partition_name(producer_rank)
|
||||
producer_DAG_node = DAG[producer_partition_name]
|
||||
producer_output_offsets = producer_DAG_node['output'][cur_DAG_node_name]
|
||||
|
||||
if producer_partition_name != 'MODEL_INPUT' and DAG[producer_partition_name]['output_len'] == 1:
|
||||
producer_outputs[producer_partition_name] = [args_from_one_mod]
|
||||
else:
|
||||
producer_outputs[producer_partition_name] = [args_from_one_mod[offset] for offset in producer_output_offsets]
|
||||
|
||||
cur_DAG_node_input = DAG[cur_DAG_node_name]['input']
|
||||
|
||||
def get_input_len(DAG_node_input):
|
||||
res = 0
|
||||
for offsets in DAG_node_input.values():
|
||||
res += len(offsets)
|
||||
return res
|
||||
|
||||
input_len = get_input_len(cur_DAG_node_input)
|
||||
flatten_args = [None] * input_len
|
||||
for producer_partition_name, args_input_offsets in cur_DAG_node_input.items():
|
||||
for i, args_input_offset in enumerate(args_input_offsets):
|
||||
flatten_args[args_input_offset] = producer_outputs[producer_partition_name][i]
|
||||
|
||||
args_or_kwargs = flatten_args
|
||||
return args_or_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _get_work_item_key(self) -> UniqueKey:
|
||||
@@ -353,6 +563,9 @@ class WorkerBase(ABC):
|
||||
|
||||
def is_last_stage(self):
|
||||
return self.pp_rank == self.actual_stage_num - 1
|
||||
|
||||
def need_model_input(self):
|
||||
return not self.is_first_stage() and self._is_input
|
||||
|
||||
def _default_data_process_func(self, args_kwargs):
|
||||
if self.is_first_stage():
|
||||
@@ -390,11 +603,11 @@ class WorkerBase(ABC):
|
||||
|
||||
# parse and integrate args and kwargs
|
||||
if is_first_stage:
|
||||
args = get_real_args_kwargs(args)
|
||||
kwargs = get_real_args_kwargs(kwargs)
|
||||
args = self._get_real_args_kwargs(args)
|
||||
kwargs = self._get_real_args_kwargs(kwargs)
|
||||
args_kwargs = (args, kwargs)
|
||||
else:
|
||||
args_kwargs = get_real_args_kwargs(args)
|
||||
args_kwargs = self._get_real_args_kwargs(args)
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
@@ -486,7 +699,7 @@ class WorkerBase(ABC):
|
||||
|
||||
# overlap recompute and future.wait
|
||||
if not is_last_stage:
|
||||
grad_tensors = get_real_args_kwargs(args)
|
||||
grad_tensors = self._get_real_args_kwargs(args)
|
||||
else:
|
||||
grad_tensors = None
|
||||
|
||||
@@ -569,7 +782,10 @@ class WorkerBase(ABC):
|
||||
self._reset_context()
|
||||
|
||||
def initialize_optimizer(self, optimizer_class: type, **kwargs):
|
||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||
# After tracer fixed, remove this part.
|
||||
if len(list(self.module_partition.parameters())) > 0:
|
||||
self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
|
||||
self.step_lock = threading.Lock()
|
||||
self.step_lock.acquire()
|
||||
|
||||
@@ -577,8 +793,11 @@ class WorkerBase(ABC):
|
||||
self.step_lock.acquire()
|
||||
|
||||
def step(self):
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# TODO(jiangziyue) it's temporary code to deal with empty module partition.
|
||||
# After tracer fixed, remove this part.
|
||||
if len(list(self.module_partition.parameters())) > 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.step_lock.release()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user