diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index 0501d878b..687e220bd 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -18,13 +18,12 @@ class BaseSchedule(ABC):
     control of FP16 in class schedule.
 
     Args:
-        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
-            and it will be executed in load_batch.
+        data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges it into data and label.
     """
 
-    def __init__(self, batch_data_process_func: Callable = None):
+    def __init__(self, data_process_func: Callable = None):
         self.logger = get_dist_logger()
-        self.batch_data_process_func = batch_data_process_func
+        self.data_process_func = data_process_func
 
     @staticmethod
     def _move_tensor(element):
@@ -34,16 +33,24 @@ class BaseSchedule(ABC):
         return element
 
     def _move_to_device(self, data):
-        if isinstance(data, dict):
+        if isinstance(data, torch.Tensor):
+            data = data.to(get_current_device())
+        elif isinstance(data, (list, tuple)):
+            data = [self._move_tensor(v) for v in data]
+        elif isinstance(data, dict):
             data = {k: self._move_tensor(v) for k, v in data.items()}
         else:
-            data = self._move_tensor(data)
+            raise TypeError(
+                f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
         return data
 
-    @staticmethod
-    def _check_sanity(data, tag: str):
-        assert isinstance(data, (torch.Tensor, dict)), \
-            f'{tag} must be torch.Tensor or dict'
+    def _get_batch_size(self, data):
+        if isinstance(data, torch.Tensor):
+            return data.size(0)
+        elif isinstance(data, (list, tuple)):
+            return data[0].size(0)
+        elif isinstance(data, dict):
+            return next(iter(data.values())).size(0)
 
     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
@@ -60,19 +67,10 @@ class BaseSchedule(ABC):
             raise RuntimeError('Dataloader is not defined.')
         batch_data = next(data_iter)
 
-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
-        else:
-            data, label = batch_data
-        self._check_sanity(data, 'data')
-        self._check_sanity(label, 'label')
-        if isinstance(data, torch.Tensor):
-            self.batch_size = data.size(0)
-        else:
-            self.batch_size = next(iter(data.values())).size(0)
         if to_gpu:
-            return self._move_to_device(data), self._move_to_device(label)
-        return data, label
+            batch_data = self._move_to_device(batch_data)
+        self.batch_size = self._get_batch_size(batch_data)
+        return batch_data
 
     def pre_processing(self, engine):
        """To perform actions before running the schedule.
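For illustration only (not part of the patch): a minimal sketch of the new load_batch() contract. It assumes NonPipelineSchedule is importable from colossalai.engine.schedule and passes to_gpu=False to sidestep device placement; the toy batch is hypothetical.

    import torch
    from colossalai.engine.schedule import NonPipelineSchedule

    batch = (torch.randn(8, 3), torch.randint(0, 10, (8,)))   # a (data, label) tuple
    schedule = NonPipelineSchedule()
    batch_data = schedule.load_batch(iter([batch]), to_gpu=False)
    assert schedule.batch_size == 8   # recorded from size(0) of the first tensor
    data, label = batch_data          # splitting is now deferred to forward_backward_step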
@@ -101,8 +99,13 @@ class BaseSchedule(ABC):
     def _call_engine(engine, inputs):
         if isinstance(inputs, torch.Tensor):
             return engine(inputs)
-        else:
+        elif isinstance(inputs, (list, tuple)):
+            return engine(*inputs)
+        elif isinstance(inputs, dict):
             return engine(**inputs)
+        else:
+            raise TypeError(
+                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")
 
     @staticmethod
     def _call_engine_criterion(engine, outputs, labels):
@@ -112,6 +115,17 @@ class BaseSchedule(ABC):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         if isinstance(labels, torch.Tensor):
-            return engine.criterion(*outputs, labels)
-        else:
+            labels = (labels,)
+
+        if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
+            return engine.criterion(*outputs, *labels)
+        elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
             return engine.criterion(*outputs, **labels)
+        elif isinstance(outputs, dict) and isinstance(labels, dict):
+            return engine.criterion(**outputs, **labels)
+        elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
+            raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
+        else:
+            raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor "
+                            f"(which is auto-converted to tuple), list, tuple, or dict, "
+                            f"but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
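For illustration only (not part of the patch): a self-contained sketch of the dispatch rules in _call_engine_criterion above; criterion stands in for engine.criterion, and the tensors are toy values.

    import torch
    import torch.nn.functional as F

    def criterion(logits, targets):
        return F.cross_entropy(logits, targets)

    outputs = torch.randn(4, 10)                        # a tensor is wrapped as (outputs,)
    labels = {'targets': torch.randint(0, 10, (4,))}    # dict labels are splatted as **labels

    loss = criterion(*(outputs,), **labels)             # the (tuple outputs, dict labels) branch
    # with tuple labels, the call becomes criterion(*outputs, *labels) instead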
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py
index d9c081dc8..e6e31a195 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -4,9 +4,10 @@
 from typing import Iterable
 
 import torch
-
+import inspect
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
+from typing import Callable
 
 
 class NonPipelineSchedule(BaseSchedule):
@@ -16,10 +17,32 @@ class NonPipelineSchedule(BaseSchedule):
     to update the parameters if it is in training mode.
 
     Args:
-        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
+        data_process_func (Callable, optional): The preprocessing function which receives a batch of data
+            and returns a tuple in the form of (data, label),
             and it will be executed in load_batch.
+
+    Example:
+        # this shows an example of a customized data_process_func
+        def data_process_func(dataloader_output):
+            item1, item2, item3 = dataloader_output
+            data = (item1, item2)
+            label = item3
+            return data, label
     """
 
+    def __init__(self, data_process_func: Callable = None):
+        # check that the data_process_func of a non-pipeline schedule
+        # only takes in one parameter, which is the batch data
+
+        if data_process_func:
+            sig = inspect.signature(data_process_func)
+            assert len(sig.parameters) == 1, \
+                'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
+                'which is a tuple of tensors for the current batch, ' \
+                'i.e. data_process_func(dataloader_output).'
+
+        super().__init__(data_process_func)
+
     def forward_backward_step(self,
                               engine,
                               data_iter: Iterable,
@@ -42,7 +65,14 @@ class NonPipelineSchedule(BaseSchedule):
         """
         assert forward_only or return_loss, \
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
-        data, label = self.load_batch(data_iter)
+        batch_data = self.load_batch(data_iter)
+
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
+        else:
+            # if no data_process_func is given,
+            # we regard the batch data as a simple tuple of (data, label)
+            data, label = batch_data
 
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
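For illustration only (not part of the patch): a usage sketch of the one-argument contract asserted above; the three-item dataloader output and the import path are assumptions.

    from colossalai.engine.schedule import NonPipelineSchedule

    def my_data_process_func(dataloader_output):
        input_ids, attention_mask, labels = dataloader_output
        data = (input_ids, attention_mask)   # later fed to the model as model(*data)
        label = labels                       # later fed to the criterion
        return data, label

    schedule = NonPipelineSchedule(data_process_func=my_data_process_func)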
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 2b2c4ecbc..bcd91ea6d 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -68,19 +68,41 @@ class PipelineSchedule(BaseSchedule):
 
     Args:
         num_microbatches (int): The number of microbatches.
-        batch_data_process_func (Callable, optional):
+        data_process_func (Callable, optional):
             The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
+
+    Example:
+
+        # this shows an example of a customized data_process_func
+        def data_process_func(stage_output, dataloader_output):
+            output1, output2 = stage_output
+            item1, item2, item3 = dataloader_output
+
+            # assume item2 is not needed here
+            data = (output1, output2, item1)
+            label = item3
+            return data, label
     """
 
     def __init__(self,
                  num_microbatches,
-                 batch_data_process_func: Callable = None,
+                 data_process_func: Callable = None,
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
-        super().__init__(batch_data_process_func=batch_data_process_func)
+
+        # we need to make sure that the signature of the data_process_func is valid
+        if data_process_func:
+            sig = inspect.signature(data_process_func)
+            assert len(sig.parameters) == 2, \
+                'The data_process_func only takes in two parameters for PipelineSchedule, ' \
+                'which are the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
+                'i.e. data_process_func(stage_output, dataloader_output).'
+
+        super().__init__(data_process_func=data_process_func)
 
         assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}'
 
@@ -99,29 +121,32 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = scatter_gather_tensors
         self._logger = get_dist_logger()
 
+        # cache for the batch data
+        self.batch_data = None
+
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
-        self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
+        batch_data = super().load_batch(data_iter, to_gpu=False)
         self.microbatch_offset = 0
-        if isinstance(self.batch_data, torch.Tensor):
-            batch_size = self.batch_data.size(0)
-        else:
-            batch_size = next(iter(self.batch_data.values())).size(0)
-        assert batch_size % self.num_microbatches == 0, \
+        assert self.batch_size % self.num_microbatches == 0, \
             "Batch size should divided by the number of microbatches"
-        self.microbatch_size = batch_size // self.num_microbatches
+        self.microbatch_size = self.batch_size // self.num_microbatches
+        self.batch_data = batch_data
 
     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
+        elif isinstance(data, (list, tuple)):
+            return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
+        else:
+            raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
 
     def load_micro_batch(self):
-        data = self._get_data_slice(self.batch_data, self.microbatch_offset)
-        label = self._get_data_slice(self.batch_label, self.microbatch_offset)
+        micro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset)
         self.microbatch_offset += self.microbatch_size
-        return self._move_to_device(data), self._move_to_device(label)
+        return self._move_to_device(micro_batch_data)
 
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
@@ -137,45 +162,78 @@ class PipelineSchedule(BaseSchedule):
             assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
 
     @staticmethod
-    def _call_engine(model, input_obj, batch_data):
-        if isinstance(model, NaiveAMPModel):
-            sig = inspect.signature(model.model.forward)
-        elif hasattr(model, 'colo_attr'):
-            sig = inspect.signature(model.module.forward)
+    def _call_engine(model, data):
+        if data is not None:
+            if isinstance(data, torch.Tensor):
+                return model(data)
+            elif isinstance(data, (list, tuple)):
+                return model(*data)
+            elif isinstance(data, dict):
+                return model(**data)
+            else:
+                raise TypeError(
+                    f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
+
+    def _get_actual_forward_func(self, module):
+        if isinstance(module, NaiveAMPModel):
+            sig = inspect.signature(module.model.forward)
+        elif hasattr(module, 'colo_attr'):
+            sig = inspect.signature(module.module.forward)
         else:
-            sig = inspect.signature(model.forward)
-        if isinstance(batch_data, torch.Tensor):
-            for p in sig.parameters.values():
-                if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    if input_obj is None:
-                        return model(batch_data)
+            sig = inspect.signature(module.forward)
+        return sig
+
+    def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model):
+        if self.data_process_func:
+            # use the customized function to get data and label
+            data, label = self.data_process_func(stage_output, micro_batch_data)
+        else:
+            if isinstance(micro_batch_data, (tuple, list)):
+                if gpc.is_first_rank(ParallelMode.PIPELINE):
+                    # for the first stage, we use the data from the
+                    # dataloader output by default
+                    data, label = micro_batch_data
+                else:
+                    # for a non-first stage, we use the output passed
+                    # by the previous stage as the model input
+                    data = stage_output
+                    _, label = micro_batch_data
+            elif isinstance(micro_batch_data, dict):
+                args = []
+                data = {}
+                label = {}
+
+                # we feed the stage output to args first,
+                # then map each arg in args to its param name
+                if stage_output is not None:
+                    if isinstance(stage_output, torch.Tensor):
+                        args.append(stage_output)
+                    elif isinstance(stage_output, (list, tuple)):
+                        args.extend(stage_output)
                     else:
-                        return model(input_obj)
-            if input_obj is None:
-                return model(batch_data)
-            elif isinstance(input_obj, torch.Tensor):
-                if len(sig.parameters) > 1:
-                    return model(input_obj, batch_data)
-                else:
-                    return model(input_obj)
-            else:
-                if len(sig.parameters) > len(input_obj):
-                    return model(*input_obj, batch_data)
-                else:
-                    return model(*input_obj)
-        else:
-            filter_batch = True
-            for p in sig.parameters.values():
-                if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    filter_batch = False
-            if filter_batch:
-                batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
-            if input_obj is None and filter_batch:
-                return model(**batch_data)
-            elif isinstance(input_obj, torch.Tensor) or input_obj is None:
-                return model(input_obj, **batch_data)
-            else:
-                return model(*input_obj, **batch_data)
+                        raise TypeError(
+                            f"Expected the values passed from the previous pipeline stage to be torch.Tensor, list or tuple, but got {type(stage_output)}"
+                        )
+
+                # get all parameter names for the forward function of the model
+                fwd_sig = self._get_actual_forward_func(model)
+                fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
+
+                # build the kwargs for the forward function
+                for idx, param_name in enumerate(fwd_sig_param_name):
+                    if idx < len(args):
+                        data[param_name] = args[idx]
+                    else:
+                        if param_name in micro_batch_data:
+                            data[param_name] = micro_batch_data[param_name]
+
+                # get the tensors for the loss function
+                loss_sig = inspect.signature(criterion)
+                loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
+
+                for param_name in loss_sig_param_name:
+                    if param_name in micro_batch_data:
+                        label[param_name] = micro_batch_data[param_name]
+        return data, label
 
     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
@@ -191,8 +249,11 @@ class PipelineSchedule(BaseSchedule):
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
-        data, label = self.load_micro_batch()
-        output_obj = self._call_engine(engine.model, input_obj, data)
+        micro_batch_data = self.load_micro_batch()
+
+        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model)
+
+        output_obj = self._call_engine(engine.model, data)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
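For illustration only (not part of the patch): a standalone sketch of the dict branch of _get_data_label_for_current_step above. Stage outputs fill the leading parameters of the model's forward signature; the remaining forward parameters and the criterion parameters are looked up by name in the micro-batch dict. All names below are hypothetical.

    import inspect
    import torch

    def forward(hidden_states, attention_mask):    # stands in for model.forward
        return hidden_states

    def criterion(output, labels):                 # stands in for engine.criterion
        return (output.sum() - labels.sum()).abs()

    stage_output = torch.randn(2, 4)               # tensor passed from the previous stage
    micro_batch_data = {'attention_mask': torch.ones(2, 4), 'labels': torch.zeros(2)}

    args, data, label = [stage_output], {}, {}
    fwd_param_names = [p.name for p in inspect.signature(forward).parameters.values()]
    for idx, name in enumerate(fwd_param_names):
        if idx < len(args):
            data[name] = args[idx]                    # stage output -> hidden_states
        elif name in micro_batch_data:
            data[name] = micro_batch_data[name]       # attention_mask from the dataloader
    for p in inspect.signature(criterion).parameters.values():
        if p.name in micro_batch_data:
            label[p.name] = micro_batch_data[p.name]  # labels routed to the loss
    # data  == {'hidden_states': ..., 'attention_mask': ...}
    # label == {'labels': ...}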
@@ -399,7 +460,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
     def __init__(self,
                  num_microbatches: int,
                  num_model_chunks: int,
-                 batch_data_process_func: Callable = None,
+                 data_process_func: Callable = None,
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
@@ -409,7 +470,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         Args:
             num_microbatches (int): The number of microbatches.
             num_model_chunks (int): The number of model chunks.
-            batch_data_process_func (Callable, optional):
+            data_process_func (Callable, optional):
                 The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
             tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
             scatter_gather_tensors (bool, optional):
@@ -420,7 +481,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
             f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
         super().__init__(num_microbatches,
-                         batch_data_process_func=batch_data_process_func,
+                         data_process_func=data_process_func,
                          tensor_shape=tensor_shape,
                          scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
@@ -446,9 +507,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
 
     def load_micro_batch(self, model_chunk_id):
         data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
-        label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
-        return self._move_to_device(data), self._move_to_device(label)
+        return self._move_to_device(data)
 
     def _forward_step(self,
                       engine,
@@ -471,8 +531,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
-        data, label = self.load_micro_batch(model_chunk_id)
-        output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)
+        micro_batch_data = self.load_micro_batch(model_chunk_id)
+        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
+                                                            engine.model[model_chunk_id])
+
+        output_obj = self._call_engine(engine.model[model_chunk_id], data)
 
         if gpc.is_pipeline_last_stage():
             if return_output_label:
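For illustration only (not part of the patch): a sketch of the two-argument contract enforced for pipeline schedules; the dataloader layout is hypothetical, and the schedule construction is commented out because it requires a launched ColossalAI context.

    def my_pipeline_data_process_func(stage_output, dataloader_output):
        item1, item2, item3 = dataloader_output
        if stage_output is None:    # first pipeline stage: model input comes from the dataloader
            data = (item1, item2)
        else:                       # later stages consume the previous stage's output
            data = stage_output
        label = item3
        return data, label

    # schedule = PipelineSchedule(num_microbatches=4,
    #                             data_process_func=my_pipeline_data_process_func)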