diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index 6341e907c..2ce959568 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -29,26 +29,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                          partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
-
-    if gpc.get_local_rank(parallel_mode) == 0:
-
-        partitioned_state_list = [dict() for _ in range(depth)]
-
-        for key in list(state_dict.keys()):
-            param = state_dict.pop(key)
-            dim = dims.get(key, 0)
-            do_partition = partition_states.get(key, True)
-            if do_partition:
-                param = torch.chunk(param, depth, dim=dim)
-            for i, p in enumerate(partitioned_state_list):
-                p[key] = param[i] if do_partition else param
-
-    else:
-        partitioned_state_list = [None for _ in range(depth)]
-
-    partitioned_state = [None]
-    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
-    return partitioned_state[0]
+    group = gpc.get_cpu_group(parallel_mode)
+    is_rank0 = gpc.get_local_rank(parallel_mode) == 0
+    partition_info = [None]
+    if is_rank0:
+        partition_info_dict = OrderedDict()
+        for key, param in state_dict.items():
+            dim = dims[key]
+            is_partitioned = partition_states[key]
+            shape = list(param.shape)
+            if is_partitioned:
+                shape[dim] = shape[dim] // depth
+            partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
+        partition_info[0] = partition_info_dict
+    dist.broadcast_object_list(partition_info, src_rank, group=group)
+    partitioned_state = OrderedDict()
+    for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
+        if is_partitioned:
+            output = torch.empty(shape, dtype=dtype)
+            if is_rank0:
+                scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
+            else:
+                scatter_list = None
+            dist.scatter(output, scatter_list, src_rank, group=group)
+        else:
+            if is_rank0:
+                output = state_dict[key]
+            else:
+                output = torch.empty(shape, dtype=dtype)
+            dist.broadcast(output, src_rank, group=group)
+        partitioned_state[key] = output
+    return partitioned_state
 
 
 def gather_tensor_parallel_state_dict(
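
For context, below is a minimal standalone sketch of the pattern the patch switches to: broadcast per-key metadata first so every rank can allocate receive buffers, then scatter partitioned tensors and broadcast replicated ones, instead of pickling whole per-rank dicts through `scatter_object_list`. The helper name `demo_partition` and its use of the default process group are illustrative assumptions, not part of the ColossalAI API.

```python
from collections import OrderedDict

import torch
import torch.distributed as dist


def demo_partition(state_dict: OrderedDict, dims: dict, partition_states: dict,
                   src_rank: int = 0) -> OrderedDict:
    """Split a state dict across the default process group; `src_rank` holds the full copy."""
    depth = dist.get_world_size()
    is_src = dist.get_rank() == src_rank

    # Step 1: the source rank describes each entry (partitioned?, dtype, local shape, dim)
    # and broadcasts that metadata so the other ranks can allocate receive buffers.
    meta = [None]
    if is_src:
        info = OrderedDict()
        for key, param in state_dict.items():
            shape = list(param.shape)
            if partition_states[key]:
                shape[dims[key]] //= depth  # assumes the dim is divisible by world size
            info[key] = (partition_states[key], param.dtype, shape, dims[key])
        meta[0] = info
    dist.broadcast_object_list(meta, src=src_rank)

    # Step 2: partitioned tensors are chunked on the source rank and scattered;
    # replicated tensors are broadcast whole.
    out = OrderedDict()
    for key, (is_part, dtype, shape, dim) in meta[0].items():
        if is_part:
            buf = torch.empty(shape, dtype=dtype)
            chunks = [c.contiguous() for c in state_dict[key].chunk(depth, dim)] if is_src else None
            dist.scatter(buf, chunks, src=src_rank)
        else:
            buf = state_dict[key] if is_src else torch.empty(shape, dtype=dtype)
            dist.broadcast(buf, src=src_rank)
        out[key] = buf
    return out
```

To try it, launch with `torchrun` after `dist.init_process_group("gloo")`; non-source ranks can pass empty dicts, since only the broadcast metadata and the received buffers are used there. Only the metadata goes through object serialization, while the tensor payloads move via collective ops on preallocated buffers, which is presumably the motivation for the change.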