[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
d2h_rid = fwd_info.get('d2h_rid', None)
d2h_rid = fwd_info.get("d2h_rid", None)
if d2h_rid is not None:
free_region = GlobalRuntimeInfo().region_list[d2h_rid]
assert isinstance(free_region, Region)
free_region.free_cuda_data()
h2d_rid = fwd_info.get('h2d_rid', None)
h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(h2d_region, Region)
@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def forward(ctx, input_, fwd_info, bwd_info):
ctx.bwd_info = bwd_info
sync_rid = fwd_info.get('sync_rid', None)
sync_rid = fwd_info.get("sync_rid", None)
if sync_rid is not None:
prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
if prefetch_event:
prefetch_event.wait()
h2d_rid = fwd_info.get('h2d_rid', None)
h2d_rid = fwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
sync_rid = ctx.bwd_info.get('sync_rid', None)
sync_rid = ctx.bwd_info.get("sync_rid", None)
if sync_rid is not None:
wait_region = GlobalRuntimeInfo().region_list[sync_rid]
assert isinstance(wait_region, Region)
@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
else:
wait_region.move_param_to_cuda()
h2d_rid = ctx.bwd_info.get('h2d_rid', None)
h2d_rid = ctx.bwd_info.get("h2d_rid", None)
if h2d_rid is not None:
pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
assert isinstance(pref_region, Region)
@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
"""
Convert Upload and Offload operation into runtime action.
Argument:
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
'''
"""
with torch._C.DisableTorchFunction():
ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
'''
"""
Convert Prefetch and Offload operation into runtime action.
Argument:
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
'''
"""
with torch._C.DisableTorchFunction():
ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
return ret
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
fwd_info = {}
if requires_upload_p_in_fwd(region_list[region.shared_rid]):
fwd_info['h2d_rid'] = region.r_id
fwd_info["h2d_rid"] = region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward upload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
new_node = mod_graph.create_node(
"call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
)
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p = [region for region in region_list if region.param_size][0]
fwd_info = {"h2d_rid": first_region_with_p.r_id}
with mod_graph.inserting_after(last_inp_node):
upload_apply_node = mod_graph.create_node('call_function',
convert_fwd_upload_bwd_offload_to_action,
args=(last_inp_node, fwd_info, {}))
upload_apply_node = mod_graph.create_node(
"call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
)
replace_node_users(last_inp_node, upload_apply_node)
last_inp_node = upload_apply_node
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
fwd_info = {}
if region.param_size:
fwd_info['sync_rid'] = region.r_id
fwd_info["sync_rid"] = region.r_id
fwd_prefetch_region = region.fwd_prefetch_region
if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
fwd_info["h2d_rid"] = fwd_prefetch_region.r_id
# forward offload
if r_idx > 0 and region_list[r_idx - 1].need_offload:
fwd_info['d2h_rid'] = r_idx - 1
fwd_info["d2h_rid"] = r_idx - 1
bwd_info = {}
# backward prefetch
if r_idx > 0 and region_list[r_idx - 1].need_offload:
bwd_info['sync_rid'] = r_idx - 1
bwd_info["sync_rid"] = r_idx - 1
if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id
if fwd_info or bwd_info:
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info))
new_node = mod_graph.create_node(
"call_function",
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, fwd_info, bwd_info),
)
replace_node_users(last_inp_node, new_node)
last_inp_node = region.nodes[-1]
if region.bwd_prefetch_region:
bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
with mod_graph.inserting_after(last_inp_node):
new_node = mod_graph.create_node('call_function',
convert_fwd_prefetch_bwd_offload_to_action,
args=(last_inp_node, {}, bwd_info))
new_node = mod_graph.create_node(
"call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
)
replace_node_users(last_inp_node, new_node)
# gm.graph.print_tabular()
return gm