Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -38,8 +38,7 @@ def _binary_partition(weights: List, start: int, end: int):
 
 
 def _heap_addition(weights: List, intervals: int, add_cnt: int):
-    """
-    """
+    """ """
 
     def _heap_push(heap, st, ed):
         value = weights[ed - 1]
@@ -113,8 +112,9 @@ def _binary_search(weights, num):
 
 
 def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
-    assert num_items % num_chunks == 0, \
-        "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
+    assert (
+        num_items % num_chunks == 0
+    ), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
 
     logger = get_dist_logger()
     parts = [[] for _ in range(pipeline_parallel_size)]
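For context, the assertion guards uniform pipeline partitioning: num_items layers are cut into pipeline_parallel_size * num_chunks contiguous ranges, which only works cleanly when num_items is divisible by num_chunks. A minimal sketch of that partitioning scheme (the helper name and exact balancing rule are illustrative, not necessarily the repository's implementation):

# Illustrative sketch: each of the `num_chunks` chunks is spread across all
# pipeline ranks as contiguous (start, end) index ranges over the layers.
def partition_uniform_sketch(num_items, pipeline_parallel_size, num_chunks):
    assert num_items % num_chunks == 0, "num_items must be divisible by num_chunks"
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        start = base
        for p in range(pipeline_parallel_size):
            # ranks past `left` take one extra layer when the split is uneven
            end = start + chunk_size + (0 if p < left else 1)
            parts[p].append((start, end))
            start = end
    return parts

print(partition_uniform_sketch(8, 2, 2))  # [[(0, 2), (4, 6)], [(2, 4), (6, 8)]]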
@@ -162,7 +162,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
     elif isinstance(input_tensor, torch.Tensor):
         kwargs_offset = 1
     elif isinstance(input_tensor, (tuple, OrderedDict)):
-        #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
+        # assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
         # Huggingface will take their own structures based on OrderedDict as the output
         # between layers so we've to close this check.
         kwargs_offset = len(input_tensor)
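The comment above explains why OrderedDict is accepted alongside tuple: HuggingFace models hand over OrderedDict-based output structures between layers, and each element of such an output occupies one positional slot of the next module. A hedged sketch of how that offset can be used to bind the remaining signature parameters by name (the function and variable names here are hypothetical, not the repository's code):

import inspect
from collections import OrderedDict

def build_kwargs_sketch(function, input_tensor, kw_dict):
    # count how many leading positional parameters of `function` are already
    # filled by the previous layer's output
    if input_tensor is None:
        kwargs_offset = 0
    elif isinstance(input_tensor, (tuple, OrderedDict)):
        # OrderedDict-based outputs are treated like tuples: one slot per element
        kwargs_offset = len(input_tensor)
    else:
        kwargs_offset = 1  # a single tensor fills exactly one slot
    remaining = list(inspect.signature(function).parameters)[kwargs_offset:]
    # bind whatever is left from kw_dict by parameter name
    return {name: kw_dict[name] for name in remaining if kw_dict and name in kw_dict}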
@@ -204,21 +204,21 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
             kwargs[k] = rst
         return input_tensor
     if isinstance(input_tensor, tuple):
-        assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
+        assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None."
         sig = inspect.signature(func)
         func_args_num = len(sig.parameters)
         assert func_args_num <= len(
-            input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
+            input_tensor
+        ), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}."
         if func_args_num < len(input_tensor):
             return func(*input_tensor[:func_args_num])
         else:
             return func(*input_tensor)
-    assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
+    assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple."
     return func(input_tensor)
 
 
 def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
 
     assert func_key in func_dict, f"{func_key} is not in the function_dict."
     funcs_to_exec = func_dict[func_key]
     if isinstance(funcs_to_exec, list):
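The assertions above encode an arity-matching rule: when the inputs arrive as a tuple, only as many elements as the function's signature accepts are forwarded, and any surplus is dropped. A minimal sketch of that rule in isolation (illustrative only, not the repository's code):

import inspect

import torch

def exec_func_sketch(func, input_tensor):
    if isinstance(input_tensor, tuple):
        # trim the tuple to the number of parameters `func` accepts
        func_args_num = len(inspect.signature(func).parameters)
        assert func_args_num <= len(input_tensor), "func wants more inputs than were provided"
        return func(*input_tensor[:func_args_num])
    assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a torch.Tensor or tuple"
    return func(input_tensor)

# A callable that only needs the first two of three pipeline outputs:
out = exec_func_sketch(lambda a, b: a + b, (torch.ones(2), torch.ones(2), torch.zeros(2)))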
@@ -243,7 +243,7 @@ def call_module(module, args=None, kwargs=None):
         forward_func = module.forward
     sig = inspect.signature(forward_func)
     param_nums = len(sig.parameters)
-    feed_nums = len(args) + len(kwargs)
+    len(args) + len(kwargs)
     args_needed_nums = param_nums - len(kwargs)
     args_needed = args[:args_needed_nums]
     if isinstance(module, CheckpointModule):
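call_module applies a similar idea to the module's forward function: it counts the parameters of the forward signature, reserves slots for the keyword arguments, and forwards only the positional args that still fit. A small self-contained sketch (the helper name is invented for illustration):

import inspect

def trim_and_call(forward_func, args, kwargs):
    # cut positional args down so that, together with kwargs, they do not
    # exceed the forward signature's parameter count
    param_nums = len(inspect.signature(forward_func).parameters)
    args_needed_nums = param_nums - len(kwargs)
    args_needed = args[:args_needed_nums]
    return forward_func(*args_needed, **kwargs)

def forward(x, y, scale=1.0):
    return (x + y) * scale

print(trim_and_call(forward, args=(1, 2, 99), kwargs={"scale": 0.5}))  # 1.5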
@@ -256,17 +256,17 @@ def call_module(module, args=None, kwargs=None):
 
 
 def customized_partition(exec_seq):
-    '''
+    """
     This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
     annotation to note the partition point.
-    '''
+    """
     customized_parts = {}
     start = 0
     stop = 0
     rank = 0
     for element in exec_seq:
         if isinstance(element, str):
-            if element == 'SPLIT_NODE':
+            if element == "SPLIT_NODE":
                 customized_parts[rank] = [(start, stop)]
                 start = stop
                 rank += 1
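The docstring describes the convention rather than showing it, so here is a hypothetical exec_seq using it: every "SPLIT_NODE" string marks the end of one pipeline stage and the start of the next (the layer names are invented for the example).

# Hypothetical exec_seq illustrating the SPLIT_NODE convention
exec_seq = [
    "embedding",
    "block_0",
    "block_1",
    "SPLIT_NODE",  # end of pipeline rank 0
    "block_2",
    "block_3",
    "SPLIT_NODE",  # end of pipeline rank 1
    "head",        # remaining layers belong to the last rank
]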