Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
 from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
 
 try:
     from ..codegen import ActivationCheckpointCodeGen
+
     SUPPORT_ACTIVATION = True
 except:
     SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
 
 
 def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 
 def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
     if meta_args:
         device, orig_device = _default_device(), _current_device(root)
         wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
+        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+            root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+        )
         if trace_act_ckpt and SUPPORT_ACTIVATION:
            graph.set_codegen(ActivationCheckpointCodeGen())
         root.to(orig_device)
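The last hunk only reformats the call that builds the traced graph; behavior is unchanged. Below is a minimal usage sketch of that code path. It assumes symbolic_trace is exported from colossalai._analyzer.fx and accepts the root module together with meta_args, trace_act_ckpt, and bias_addition_split, as the parameters used in the hunk suggest; the import path, the TinyModel module, and the return type are illustrative assumptions, not taken from this diff.

import torch
import torch.nn as nn

# Assumed import path; the diff only shows the module's relative imports.
from colossalai._analyzer.fx import symbolic_trace


class TinyModel(nn.Module):
    """Hypothetical toy module used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel()

# Passing meta_args exercises the branch reformatted above: each tensor is
# wrapped into a MetaTensor on the default device, the module is traced
# there, and then moved back to its original device.
gm = symbolic_trace(
    model,
    meta_args={"x": torch.rand(2, 16, device="meta")},
    trace_act_ckpt=False,
    bias_addition_split=False,
)
print(gm.graph)  # assumed to return a torch.fx GraphModule-like object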