mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-22 07:14:09 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			267 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			267 lines
		
	
	
		
			8.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy
 | |
| import heapq
 | |
| 
 | |
| 
 | |
| from colossalai.builder import build_model, build_layer
 | |
| from colossalai.context.parallel_mode import ParallelMode
 | |
| from colossalai.core import global_context as gpc
 | |
| from colossalai.logging import get_dist_logger
 | |
| import torch.nn as nn
 | |
| 
 | |
| 
 | |
| def _binary_partition(weights, st, ed):
 | |
|     """Returns the binary partition position of `weights`, given the start
 | |
|     position `st` and the end position `ed`.
 | |
| 
 | |
|     :param weights: A python list to be binary partitioned
 | |
|     :type weights: list
 | |
|     :param st: the start position of the binary partition
 | |
|     :type st: int
 | |
|     :param ed: the end postition of the binary partition
 | |
|     :type ed: int
 | |
|     :return: the binary partition position of `weights`
 | |
|     :rtype: int
 | |
|     """
 | |
|     w_sum = weights[ed - 1]
 | |
|     prefix = 0
 | |
|     if st > 0:
 | |
|         w_sum -= weights[st - 1]
 | |
|         prefix = weights[st - 1]
 | |
|     minimum = float("inf")
 | |
|     for idx in range(st + 1, ed):
 | |
|         front = weights[idx - 1] - prefix
 | |
|         diff = abs(w_sum - 2 * front)
 | |
|         if diff < minimum:
 | |
|             pos = idx
 | |
|             minimum = diff
 | |
| 
 | |
|     return st, pos, ed
 | |
| 
 | |
| 
 | |
| def _heap_addition(weights, intervals, add_cnt):
 | |
|     """
 | |
|     """
 | |
|     def _heap_push(heap, st, ed):
 | |
|         value = weights[ed - 1]
 | |
|         if st > 0:
 | |
|             value -= weights[st - 1]
 | |
|         heapq.heappush(heap, (-value, st, ed))
 | |
| 
 | |
|     ret_intervals = []
 | |
|     heap = []
 | |
| 
 | |
|     for st, ed in intervals:
 | |
|         _heap_push(heap, st, ed)
 | |
| 
 | |
|     while add_cnt > 0:
 | |
|         _, st, ed = heapq.heappop(heap)
 | |
|         if ed - st == 1:
 | |
|             ret_intervals.append((st, ed))
 | |
|         else:
 | |
|             l, m, r = _binary_partition(weights, st, ed)
 | |
|             _heap_push(heap, l, m)
 | |
|             _heap_push(heap, m, r)
 | |
|             add_cnt -= 1
 | |
| 
 | |
|     while heap:
 | |
|         _, st, ed = heapq.heappop(heap)
 | |
|         ret_intervals.append((st, ed))
 | |
| 
 | |
|     ret_intervals.sort()
 | |
|     return ret_intervals
 | |
| 
 | |
| 
 | |
| def _calc_partitions(weights, value):
 | |
|     prev = 0
 | |
|     prefix = 0
 | |
|     num_block = 0
 | |
|     intervals = []
 | |
| 
 | |
|     for idx, w in enumerate(weights):
 | |
|         if weights[idx] - prefix > value:
 | |
|             intervals.append((prev, idx))
 | |
|             prev = idx
 | |
|             prefix = weights[idx - 1]
 | |
|             num_block += 1
 | |
| 
 | |
|     intervals.append((prev, len(weights)))
 | |
|     return num_block + 1, intervals
 | |
| 
 | |
| 
 | |
def _binary_search(weights, num):
    """Splits `weights` into consecutive intervals with a small maximum interval weight.

    A binary search looks for the smallest block capacity for which
    `_calc_partitions` reports at most `num` blocks. If that capacity yields
    fewer than `num` blocks, the missing ones are produced by `_heap_addition`.

    :param weights: A python list holding the weight of every item
    :type weights: list
    :param num: The number of intervals wanted
    :type num: int
    :return: the list of ``(start, end)`` intervals, `end` exclusive
    :rtype: list
    """
    # Inclusive prefix sums; an item of weight 0 is counted as 1.
    cumulative = []
    running = 0
    for w in weights:
        running += 1 if w == 0 else w
        cumulative.append(running)

    lo = max(weights)
    hi = cumulative[len(weights) - 1]

    # Shrink [lo, hi] down to the smallest capacity needing at most `num` blocks.
    while lo < hi:
        mid = (hi + lo) // 2
        block_count = _calc_partitions(cumulative, mid)[0]
        if block_count > num:
            lo = mid + 1
        else:
            hi = mid

    num_block, intervals = _calc_partitions(cumulative, hi)
    missing = num - num_block
    if missing > 0:
        intervals = _heap_addition(cumulative, intervals, missing)

    return intervals
| 
 | |
| 
 | |
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    """Splits `num_items` items evenly by count over the pipeline stages and chunks.

    The items are first divided into `num_chunks` equally sized groups. Each
    group is then divided over the stages; when the group size is not a
    multiple of `pipeline_parallel_size`, the last stages receive one extra
    item each. A warning is logged for every chunk in which the per-stage base
    size is 0.

    :param num_items: The number of items to be partitioned
    :type num_items: int
    :param pipeline_parallel_size: The number of pipeline stages
    :type pipeline_parallel_size: int
    :param num_chunks: The number of chunks held by every stage
    :type num_chunks: int
    :return: a list with one entry per stage, each a list of `num_chunks`
        ``(start, end)`` intervals, `end` exclusive
    :rtype: list
    """
    assert num_items % num_chunks == 0, \
        "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"

    logger = get_dist_logger()
    parts = [[] for _ in range(pipeline_parallel_size)]
    items_per_chunk = num_items // num_chunks
    for chunk_idx in range(num_chunks):
        begin = chunk_idx * items_per_chunk
        quotient, remainder = divmod(items_per_chunk, pipeline_parallel_size)
        # Stages from this index onwards take one extra item.
        first_large_stage = pipeline_parallel_size - remainder
        if quotient == 0:
            logger.warning("Some nodes in Pipeline have no requests")

        for stage, stage_parts in enumerate(parts):
            end = begin + quotient + (1 if stage >= first_large_stage else 0)
            stage_parts.append((begin, end))
            begin = end

    return parts
| 
 | |
| 
 | |
def partition_balanced(weights, pipeline_parallel_size, num_chunks):
    """Splits the items over the pipeline stages so that the interval weights are balanced.

    When there are no more items than ``pipeline_parallel_size * num_chunks``
    the work is delegated to `partition_uniform`. Otherwise the intervals found
    by `_binary_search` are dealt out to the stages in round-robin order.

    :param weights: A python list holding the weight of every item
    :type weights: list
    :param pipeline_parallel_size: The number of pipeline stages
    :type pipeline_parallel_size: int
    :param num_chunks: The number of chunks held by every stage
    :type num_chunks: int
    :return: a list with one entry per stage, each a list of ``(start, end)``
        intervals, `end` exclusive
    :rtype: list
    """
    slot_count = pipeline_parallel_size * num_chunks
    if len(weights) <= slot_count:
        return partition_uniform(len(weights), pipeline_parallel_size, num_chunks)

    intervals = _binary_search(weights, slot_count)

    parts = [[] for _ in range(pipeline_parallel_size)]
    stage = 0
    for interval in intervals:
        parts[stage].append(interval)
        # Wrap around to the first stage after the last one.
        stage += 1
        if stage == pipeline_parallel_size:
            stage = 0

    return parts
| 
 | |
| 
 | |
def count_layer_params(layers):
    """Count the number of trainable parameters in each layer

    Every entry of `layers` is built with `build_layer`, and the element counts
    of its parameters that have ``requires_grad`` set are summed up.

    :param layers: The layer configurations to be measured
    :type layers: list
    :return: the parameter count of every layer, in the order of `layers`
    :rtype: list
    """
    def _trainable_size(cfg):
        # Total element count over the parameters that require gradients.
        module = build_layer(cfg)
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    return [_trainable_size(cfg) for cfg in layers]
| 
 | |
| 
 | |
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False):
    """An initializer to split the model into different stages for pipeline parallelism.

    An example for the model config is shown below. The class VisionTransformerFromConfig should
    inherit colossalai.nn.model.ModelFromConfig to allow this initializer to build model from a sequence
    of layer configurations.

    model_config = dict(
        type='VisionTransformerFromConfig',
        embedding_cfg=dict(...),
        ...
    )

    The layers are partitioned over all pipeline stages, then one deep copy of the full model is
    made for every interval assigned to the current pipeline rank and built from that interval.

    :param config: Configuration of the model
    :type config: dict
    :param num_chunks: The number of chunks you want to have on the current stage. This value should be 1
                        in most cases unless you are using virtual pipeline parallelism.
    :type num_chunks: int, optional
    :param partition_method: This parameter determines how you want to split your model layers into stages,
                                you can set it as 'layer' or 'parameter' (case-insensitive)
    :type partition_method: str, optional
    :param verbose: Whether to print the logs
    :type verbose: bool, optional
    :return: the model of the current stage, or a `torch.nn.ModuleList` of models when the current
             stage holds more than one interval
    :raises ValueError: if `partition_method` is neither 'layer' nor 'parameter'
    """
    ori_model = build_model(config)
    layers = ori_model.layers_cfg
    logger = get_dist_logger()
    if verbose:
        logger.info(f"The total length of layers is {len(layers)}", ranks=[0])

    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Decide which layers go to which stage.
    method = partition_method.lower()
    if method == 'parameter':
        parts = partition_balanced(count_layer_params(layers), pipeline_parallel_size, num_chunks)
    elif method == 'layer':
        parts = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
    else:
        raise ValueError("Method should be a pre-set string in [layer, parameter]")

    # Report the allocation of every stage in a single log message.
    if verbose:
        pieces = ['Layer allocation after partitioning: \n']
        for stage in range(pipeline_parallel_size):
            stage_intervals = parts[stage]
            layer_total = sum(ed - st for st, ed in stage_intervals)
            pieces.append(f'\n===== stage={stage}, layers={layer_total} =====\n')
            for st, ed in stage_intervals:
                pieces.extend(f'\t{pos:2d}: {layer}\n' for pos, layer in enumerate(layers[st: ed], start=st))
        logger.info(''.join(pieces), ranks=[0])

    # Build one model per interval owned by this rank.
    models = []
    for st, ed in parts[pipeline_rank]:
        stage_model = copy.deepcopy(ori_model)
        stage_model.build_from_cfg(st, ed)
        models.append(stage_model)

    if len(models) > 1:
        return nn.ModuleList(models)
    return models[0]
| 
 | |
| 
 | |
def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
    """An initializer to split the model into different stages for pipeline parallelism.
    Note that `layers` must be `torch.nn.Sequential`.

    The layers are split uniformly by count with `partition_uniform`, and every interval assigned
    to the current pipeline rank is wrapped in its own `torch.nn.Sequential`.

    :param layers: Layers of model
    :type layers: `torch.nn.Sequential`
    :param num_chunks: The number of chunks you want to have on the current stage. This value should be 1
                        in most cases unless you are using virtual pipeline parallelism.
    :type num_chunks: int, optional
    :param verbose: Whether to print the logs
    :type verbose: bool, optional
    :return: the `torch.nn.Sequential` of the current stage, or a `torch.nn.ModuleList` of them
             when the current stage holds more than one interval
    """
    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
    module_list = [nn.Sequential(*layers[start:end]) for start, end in partitions[pipeline_rank]]

    if verbose:
        logger = get_dist_logger()
        logger.info(f'Total {len(layers)} layers', ranks=[0])
        # One log message per stage, listing the layers of each of its chunks.
        for rank, part in enumerate(partitions):
            pieces = [f'===== stage={rank} =====\n']
            for chunk, (start, end) in enumerate(part):
                pieces.append(f'===== chunk={chunk}, layer=[{start}-{end}] =====\n')
                pieces.append('\n'.join(str(layer) for layer in layers[start:end]) + '\n')
            logger.info(''.join(pieces), ranks=[0])

    if len(module_list) > 1:
        return nn.ModuleList(module_list)
    return module_list[0]