Mirror of https://github.com/hpcaitech/ColossalAI.git
	[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py code style (#2368)
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py
@@ -5,9 +5,9 @@ from functools import reduce
 from typing import Dict, List
 
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 
@@ -42,19 +42,19 @@ class EmbeddingHandler(OperatorHandler):
         Argument:
             sharding_size_forward(int): The forward activation will be divided
                 into sharding_size_forward number partions.
-            sharding_size_backward_activation(int): The backward activation will 
+            sharding_size_backward_activation(int): The backward activation will
                 be divided into sharding_size_backward_activation number partions.
             sharding_size_weight(int): The backward weight will be divided
                 into sharding_size_weight number partions.
 
         Return:
-            memory_cost(Tuple[float]): Memory cost per device with this 
+            memory_cost(Tuple[float]): Memory cost per device with this
                 specific strategy, the first element of this tuple is forward
                 memory cost, and the second element of this tuple is backward
                 memory cost.
-            memory_cost_forward(float): Memory cost of forward activation per 
+            memory_cost_forward(float): Memory cost of forward activation per
                 device with this specific strategy.
-            memory_cost_backward_activation(float): Memory cost of backward activation 
+            memory_cost_backward_activation(float): Memory cost of backward activation
                 per device with this specific strategy.
         '''
         # compute the memory cost of this strategy
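The import changes in the first hunk are the kind of normalization an import sorter produces: the backslash-continued import and the redundant parentheses are collapsed onto single lines. A minimal sketch of reproducing that step, assuming isort 5.x is installed (the exact output in the repository depends on its own isort/line-length configuration, which is not shown here):

    # Hypothetical illustration, not part of the commit: normalizing a
    # backslash-continued import with isort's Python API (isort >= 5).
    import isort

    messy = (
        "from colossalai.auto_parallel.tensor_shard.deprecated._utils import \\\n"
        "    ignore_sharding_exception\n"
    )

    # isort.code() returns the reformatted source; with a large enough
    # line length the import collapses onto a single line, as in this hunk.
    print(isort.code(messy, line_length=120))

The second hunk only strips trailing whitespace from the docstring; the text itself is unchanged.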