From f1f51990b94ac5acec7c213603cf5d832046575f Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 21 Jun 2022 22:46:30 +0800
Subject: [PATCH] [hotfix]fix some bugs caused by refactored schedule. (#1148)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix some bugs caused by refactored schedule.
---
 colossalai/engine/schedule/_base_schedule.py         | 8 +++++++-
 colossalai/engine/schedule/_non_pipeline_schedule.py | 1 -
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index a144db6a0..b30aff784 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -36,7 +36,13 @@ class BaseSchedule(ABC):
         if isinstance(data, torch.Tensor):
             data = data.to(get_current_device())
         elif isinstance(data, (list, tuple)):
-            data = [self._move_tensor(v) for v in data]
+            data_to_return = []
+            for element in data:
+                if isinstance(element, dict):
+                    data_to_return.append({k: self._move_tensor(v) for k, v in element.items()})
+                else:
+                    data_to_return.append(self._move_tensor(element))
+            data = data_to_return
         elif isinstance(data, dict):
             data = {k: self._move_tensor(v) for k, v in data.items()}
         else:
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py
index 8e41df53b..c62bfb7d7 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -66,7 +66,6 @@ class NonPipelineSchedule(BaseSchedule):
         assert forward_only or return_loss, \
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
         batch_data = self.load_batch(data_iter)
-
         if self.data_process_func:
             data, label = self.data_process_func(batch_data)
         else:
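
Note (not part of the patch): the first hunk makes BaseSchedule tolerate batches that are lists or tuples containing dicts of tensors, which the old one-level list comprehension could not move to the device. The following is a minimal, self-contained sketch of that behaviour, not the ColossalAI class itself; the function name move_batch_to_device and the use of a plain torch.device are illustrative assumptions, and the sketch generalizes the fix to full recursion, whereas the patch only special-cases dicts nested one level inside a list or tuple.

    # Sketch only: recursive device placement for nested batch structures.
    import torch

    def move_batch_to_device(data, device):
        """Recursively move tensors in nested lists/tuples/dicts to `device`."""
        if isinstance(data, torch.Tensor):
            return data.to(device)
        elif isinstance(data, (list, tuple)):
            # After the fix, dict elements inside a list/tuple are handled too.
            return [move_batch_to_device(element, device) for element in data]
        elif isinstance(data, dict):
            return {k: move_batch_to_device(v, device) for k, v in data.items()}
        else:
            # Non-tensor payloads (ints, strings, ...) pass through unchanged.
            return data

    if __name__ == "__main__":
        # Example batch shape that previously broke: a list whose first element is a dict.
        batch = [{"input_ids": torch.zeros(2, 8), "mask": torch.ones(2, 8)},
                 torch.arange(2)]
        moved = move_batch_to_device(batch, torch.device("cpu"))
        print(moved[0]["input_ids"].device, moved[1].device)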