mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -5,8 +5,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
|
||||
from .manager import ChunkManager
|
||||
from .search_utils import search_chunk_configuration
|
||||
|
||||
@@ -17,15 +15,17 @@ def safe_div(a, b):
|
||||
return a / b
|
||||
|
||||
|
||||
def init_chunk_manager(model: nn.Module,
|
||||
init_device: Optional[torch.device] = None,
|
||||
hidden_dim: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs) -> ChunkManager:
|
||||
def init_chunk_manager(
|
||||
model: nn.Module,
|
||||
init_device: Optional[torch.device] = None,
|
||||
hidden_dim: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
) -> ChunkManager:
|
||||
if hidden_dim:
|
||||
search_interval = hidden_dim
|
||||
else:
|
||||
search_interval = 1024 # defaults to 1024
|
||||
search_interval = 1024 # defaults to 1024
|
||||
kwargs["search_interval"] = search_interval
|
||||
|
||||
dist.barrier()
|
||||
@@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module,
|
||||
wasted_size /= mega_unit
|
||||
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep='',
|
||||
flush=True)
|
||||
print(
|
||||
"searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
|
||||
"used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
|
||||
"total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
|
||||
sep="",
|
||||
flush=True,
|
||||
)
|
||||
dist.barrier()
|
||||
|
||||
chunk_manager = ChunkManager(config_dict, init_device)
|
||||
|
Reference in New Issue
Block a user