hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git)

[zero] find miss code (#378)

This commit adds a BucketizedTensorCopy utility (plus a unit test) for batching CPU fp32 -> GPU fp16 parameter copies, clarifies the comments around the fp32-master to fp16-payload copy logic in ShardedOptimizerV2.step(), makes ShardedTensor.copy_payload() shape-agnostic by flattening both source and destination, and drops the flake8 hook from the pre-commit configuration.
Pre-commit configuration (.pre-commit-config.yaml): remove the flake8 hook.

@@ -4,10 +4,6 @@ repos:
     hooks:
     - id: yapf
       args: ['--style=.style.yapf', '--parallel', '--in-place']
-  - repo: https://github.com/pycqa/flake8
-    rev: '4.0.1'
-    hooks:
-    - id: flake8
   - repo: https://github.com/pre-commit/mirrors-clang-format
     rev: v13.0.1
     hooks:
colossalai/utils/commons/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
+from .bucket_tensor_copy import BucketizedTensorCopy
+
+__all__ = ['BucketizedTensorCopy']
colossalai/utils/commons/bucket_tensor_copy.py (new file, 61 lines)

@@ -0,0 +1,61 @@
+import torch
+from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils import get_current_device
+from typing import List
+
+
+class BucketizedTensorCopy(object):
+
+    def __init__(
+        self,
+        chunk_size: int,
+    ):
+        r"""
+        torch.nn.Parameter CPU (fp32) -> ShardedParam GPU (fp16)
+        TODO(jiaruifang) The class is a little bit hardcoded
+        I will make it more general later.
+        """
+
+        self.chunk_size = chunk_size
+        self._offset = 0
+        self._cpu_buffer = torch.empty(chunk_size, dtype=torch.float, device=torch.device("cpu:0"), pin_memory=True)
+        self._cuda_buffer = torch.empty(chunk_size,
+                                        dtype=torch.half,
+                                        device=torch.device(f"cuda:{get_current_device()}"))
+
+        self._buffered_param_list: List[ShardedParamV2] = []
+        self._numel_list = []
+
+    def copy(self, src_param: torch.nn.Parameter, target_param: ShardedParamV2):
+        assert isinstance(target_param, ShardedParamV2)
+        assert isinstance(src_param, torch.nn.Parameter)
+
+        numel = src_param.numel()
+
+        if self._offset + numel > self.chunk_size:
+            self.flush()
+
+        assert src_param.data.device.type == 'cpu'
+        self._cpu_buffer.narrow(0, self._offset, numel).copy_(src_param.data.view(-1))
+
+        self._buffered_param_list.append(target_param)
+        self._numel_list.append(numel)
+
+        self._offset += numel
+
+    def flush(self):
+        """
+        flush to cuda memory
+        """
+        self._cuda_buffer.copy_(self._cpu_buffer)
+        flush_offset = 0
+        for sparam, numel in zip(self._buffered_param_list, self._numel_list):
+            sparam.data.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
+            flush_offset += numel
+
+        self.reset()
+
+    def reset(self):
+        self._buffered_param_list = []
+        self._numel_list = []
+        self._offset = 0
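The new class batches many small fp32 parameters from CPU memory into one pinned staging buffer so they can be moved to the GPU in a single chunked transfer (with an fp16 cast) instead of one small copy per parameter; pinned (page-locked) memory also allows the host-to-device copy to run asynchronously. Below is a minimal, ColossalAI-free sketch of that staging pattern in plain PyTorch; the buffer size and tensors are illustrative and not part of the commit, and it shows only the packing and transfer idea, not the exact flush path above.

    import torch

    chunk_size = 64
    # Page-locked (pinned) host memory enables fast, potentially asynchronous H2D copies.
    cpu_buffer = torch.empty(chunk_size, dtype=torch.float, pin_memory=True)

    # Stand-ins for small fp32 parameters that live on the CPU.
    cpu_tensors = [torch.randn(2, 3), torch.randn(5), torch.randn(8)]

    offset = 0
    for t in cpu_tensors:
        n = t.numel()
        # Pack each tensor, flattened, into the next free slice of the staging buffer.
        cpu_buffer.narrow(0, offset, n).copy_(t.view(-1))
        offset += n

    if torch.cuda.is_available():
        # One transfer (and fp16 cast) for the whole bucket instead of one per tensor.
        cuda_buffer = cpu_buffer[:offset].to(device='cuda', dtype=torch.half, non_blocking=True)
        # Slices of cuda_buffer (via narrow) would then be copied into each target's payload.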
ShardedOptimizerV2.step(): clarify the comments around the fp32-master and fp16-payload copies.

@@ -88,14 +88,19 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return
 
-        # Write master param to p.data
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
         for group in self.optim.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
                 # Now p.data is sharded
                 # So optimizer states are sharded naturally
+
         ret = self.optim.step(*args, **kwargs)
-        # Write master param to payload
+
+        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transfering
+        # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
                 is_param_sharded = p.col_attr.data.is_sharded
@@ -108,7 +113,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     self.shard_strategy.shard([p.col_attr.data])
                 # We have to use `copy_payload` instead of `reset_payload`
                 # Since p.data is fp32 and p.col_attr.data is fp16
+
+                # TODO() optimize this line
                 p.col_attr.data.copy_payload(p.data)
+
                 if not is_param_sharded:
                     # We gather full fp16 param here
                     self.shard_strategy.gather([p.col_attr.data])
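The flow in step() is: point each p.data at its fp32 master shard, let the wrapped optimizer update those masters, then copy the updated fp32 values back into the fp16 payload held by col_attr, re-sharding or re-gathering as needed. The following is a minimal, self-contained sketch of that fp32-master / fp16-working-copy idea in plain single-process PyTorch, with no sharding and no ColossalAI APIs; all names in it are illustrative.

    import torch

    # fp16 working parameter, as the model itself would hold it.
    fp16_param = torch.nn.Parameter(torch.randn(4, dtype=torch.half))

    # fp32 master copy that the optimizer actually updates.
    master = fp16_param.detach().clone().float().requires_grad_(True)
    optim = torch.optim.SGD([master], lr=0.1)

    # Pretend a backward pass produced a gradient; step on the fp32 master.
    master.grad = torch.randn(4)
    optim.step()

    # Write the updated fp32 master back into the fp16 payload, analogous to
    # the copy_payload(p.data) call in the hunk above (here a plain cast-and-copy).
    with torch.no_grad():
        fp16_param.copy_(master.half())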
ShardedTensor: remove the direct payload assignment in __init__ and flatten both sides in copy_payload().

@@ -14,7 +14,6 @@ class ShardedTensor(object):
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
         self._is_sharded = False
-        self._payload = tensor
 
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
@@ -41,7 +40,7 @@
         return self._payload
 
     def copy_payload(self, tensor):
-        self._payload.copy_(tensor)
+        self._payload.view(-1).copy_(tensor.view(-1))
 
     def reset_payload(self, tensor):
         del self._payload
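Flattening both tensors matters because the bucketed copy hands copy_payload() a 1-D slice narrowed out of the staging buffer, while the payload keeps its own shape; torch.Tensor.copy_ only accepts broadcast-compatible shapes, so viewing both sides as 1-D makes the copy work whenever the element counts match. A small illustration (plain PyTorch, illustrative tensors):

    import torch

    payload = torch.zeros(2, 3, dtype=torch.half)   # stored with its original shape
    flat_src = torch.randn(6)                        # a 1-D slice taken out of a bucket

    # payload.copy_(flat_src) would raise: a (6,) source does not broadcast to (2, 3).
    payload.view(-1).copy_(flat_src.view(-1))        # fine whenever numel matches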
							
								
								
									
tests/test_utils/test_bucket_tensor_copy.py (new file, 39 lines)

@@ -0,0 +1,39 @@
+from colossalai.utils.commons import BucketizedTensorCopy
+from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.utils import free_port
+import torch
+import colossalai
+
+
+def test_bucket_copy():
+    # init dist env
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    copyer = BucketizedTensorCopy(20)
+
+    shape_list = [(2, 3), (5), (8), (12)]
+    src_param_list = []
+    tgt_param_list = []
+    for shape in shape_list:
+        # on CPU
+        src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
+        print(src_param)
+        # on GPU
+        tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))
+
+        src_param_list.append(src_param)
+        tgt_param_list.append(tgt_param)
+
+        copyer.copy(src_param, tgt_param)
+
+    copyer.flush()
+
+    for src_param, tgt_param in zip(src_param_list, tgt_param_list):
+        print(tgt_param.data.payload)
+        diff = src_param.cpu().float() - tgt_param.data.payload.cpu().float()
+        assert torch.allclose(src_param.cpu().float(), tgt_param.data.payload.cpu().float(), rtol=1e-03,
+                              atol=1e-03), f"diff {diff}"
+
+
+if __name__ == '__main__':
+    test_bucket_copy()
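The test brings up a single-process distributed environment over the NCCL backend and places the target parameters on 'cuda', so it needs at least one GPU; with that in place it can be run directly as a script (via the __main__ guard) or collected by pytest.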