Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -29,7 +29,7 @@ from .registry import (
     meta_patched_module,
 )

-__all__ = ['ColoTracer']
+__all__ = ["ColoTracer"]


 class TracerType(enum.Enum):
@@ -103,7 +103,7 @@ class ColoTracer(Tracer):
         if kind == "call_function":
             if bias_addition_function.has(target):
                 if target == torch.nn.functional.linear:
-                    if 'bias' in kwargs and kwargs['bias'] is not None:
+                    if "bias" in kwargs and kwargs["bias"] is not None:
                         function_to_substitute = func_to_func_dict[target]
                         handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
                 else:
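For context, this branch rewrites a biased `torch.nn.functional.linear` call so that the bias addition becomes a separate node in the traced graph. A minimal sketch of that decomposition, using a hypothetical stand-in for the substituted function (the real `func_to_func_dict` mapping is internal to ColossalAI):

```python
import torch
import torch.nn.functional as F

def linear_with_separate_bias(x, weight, bias=None):
    # Hypothetical illustration: compute the bias-free linear first,
    # then add the bias as its own operation so a tracer records it
    # as a distinct node after the matmul.
    out = F.linear(x, weight)
    if bias is not None:
        out = out + bias
    return out

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
assert torch.allclose(linear_with_separate_bias(x, w, b), F.linear(x, w, b), atol=1e-6)
```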
@@ -160,22 +160,27 @@ class ColoTracer(Tracer):
                     if n not in parameter_proxy_cache:
                         kwargs = {}
                         if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
-                            kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
-                                                          lambda node: ParameterProxy(self, node, n, attr_val))
-                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
+                            kwargs["proxy_factory_fn"] = (
+                                None
+                                if not self.param_shapes_constant
+                                else lambda node: ParameterProxy(self, node, n, attr_val)
+                            )
+                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                         parameter_proxy_cache[n] = val_proxy
                     return parameter_proxy_cache[n]
             return None

         if isinstance(attr_val, torch.nn.Parameter):
-            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
-                                                             parameter_proxy_cache)
+            maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_parameters(), parameter_proxy_cache
+            )
             if maybe_parameter_proxy is not None:
                 return maybe_parameter_proxy

         if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
-            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
-                                                          parameter_proxy_cache)
+            maybe_buffer_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_buffers(), parameter_proxy_cache
+            )
             if maybe_buffer_proxy is not None:
                 return maybe_buffer_proxy

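This hunk mirrors the parameter-proxying logic of `torch.fx.Tracer.getattr`, only reflowed by black: parameter and buffer accesses hit a cache and are emitted as `get_attr` nodes. A quick illustration with plain upstream `torch.fx`:

```python
import torch
import torch.fx

class Scale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        # reading self.weight during tracing goes through Tracer.getattr
        return x * self.weight

gm = torch.fx.symbolic_trace(Scale())
print([node.op for node in gm.graph.nodes])
# ['placeholder', 'get_attr', 'call_function', 'output']
```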
@@ -190,7 +195,7 @@ class ColoTracer(Tracer):
         # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
         # we should treat it as leaf module as well
         if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
-            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+            return self.create_proxy("call_module", module_qualified_name, args, kwargs)
         else:
             return forward(*args, **kwargs)

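The leaf-module check above extends the default `torch.fx` behaviour, which keeps standard `torch.nn` modules opaque (emitted as `call_module` nodes) and traces through everything else; meta-patched third-party modules get the same treatment here. The upstream default, for reference:

```python
import torch
import torch.fx

tracer = torch.fx.Tracer()
print(tracer.is_leaf_module(torch.nn.Linear(4, 4), "lin"))   # True: kept as a call_module node
print(tracer.is_leaf_module(torch.nn.Sequential(), "seq"))   # False: traced through
```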
@@ -211,7 +216,6 @@ class ColoTracer(Tracer):
             raise ValueError(f"Unrecognized tracer type {tracer_type}")

     def _meta_data_computing(self, kind, target, args, kwargs):
-
         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
             meta_out = self.meta_args[target]
             return meta_out
@@ -235,8 +239,9 @@ class ColoTracer(Tracer):
         # Therefore, I need to record the nn.parameter.Parameter attribute for the operation
         # added by the bias addition manipulation following the get_attr node.
         convert_to_parameter = False
-        if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0],
-                                                                      torch.nn.parameter.Parameter):
+        if target in (torch.transpose, torch.reshape) and isinstance(
+            args_metas[0], torch.nn.parameter.Parameter
+        ):
             convert_to_parameter = True
         # fetch patched function
         if meta_patched_function.has(target):
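The `convert_to_parameter` flag exists because of the PyTorch detail the comment hints at: `torch.transpose` and `torch.reshape` applied to an `nn.Parameter` return a plain `Tensor`, so the parameter attribute would otherwise be lost for the ops inserted by the bias-addition rewrite. A quick check of that behaviour:

```python
import torch

w = torch.nn.Parameter(torch.empty(3, 4))
wt = torch.transpose(w, 0, 1)

print(isinstance(w, torch.nn.Parameter))   # True
print(isinstance(wt, torch.nn.Parameter))  # False: the result is a plain Tensor
```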
@@ -309,10 +314,12 @@ class ColoTracer(Tracer):

         return meta_out

-    def trace(self,
-              root: nn.Module,
-              concrete_args: Optional[Dict[str, Tensor]] = None,
-              meta_args: Optional[Dict[str, Tensor]] = None) -> Graph:
+    def trace(
+        self,
+        root: nn.Module,
+        concrete_args: Optional[Dict[str, Tensor]] = None,
+        meta_args: Optional[Dict[str, Tensor]] = None,
+    ) -> Graph:
         """
         Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.

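A usage sketch for the reformatted signature, assuming the `ColoTracer` export from `colossalai.fx` (the exact import path may differ between releases); `meta_args` only need shapes and dtypes, so meta tensors are enough:

```python
import torch
from torch.fx import GraphModule
from colossalai.fx import ColoTracer  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
tracer = ColoTracer()
# "input" is the argument name of nn.Sequential.forward
graph = tracer.trace(model, meta_args={"input": torch.empty(2, 8, device="meta")})
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.code)
```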
@@ -341,9 +348,7 @@ class ColoTracer(Tracer):
         # update concrete args with default values
         non_meta_arg_names = sig_names - meta_arg_names
         for k, v in sig.parameters.items():
-            if k in non_meta_arg_names and \
-                    k not in concrete_args and \
-                    v.default is not inspect.Parameter.empty:
+            if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
                 concrete_args[k] = v.default

         # get non concrete arg names
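The collapsed condition keeps the original behaviour: any forward argument that is neither a meta arg nor an explicit concrete arg falls back to its signature default. The default lookup itself is plain `inspect`, for example:

```python
import inspect

def forward(x, scale=1.0, bias=None):
    ...

defaults = {
    name: param.default
    for name, param in inspect.signature(forward).parameters.items()
    if param.default is not inspect.Parameter.empty
}
print(defaults)  # {'scale': 1.0, 'bias': None}
```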
@@ -354,7 +359,8 @@ class ColoTracer(Tracer):
             success, element = is_element_in_list(names, sig_names)
             if not success:
                 raise KeyError(
-                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function")
+                    f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
+                )

         _check_arg_name_valid(meta_arg_names)
         _check_arg_name_valid(concrete_arg_names)
@@ -363,11 +369,13 @@ class ColoTracer(Tracer):
         def _check_kwargs(kwargs, should_be_meta: bool):
             for k, v in kwargs.items():
                 if not should_be_meta:
-                    assert not torch.is_tensor(v) or not v.is_meta, \
-                        f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer'
+                    assert (
+                        not torch.is_tensor(v) or not v.is_meta
+                    ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
                 else:
-                    assert v.is_meta == should_be_meta, \
-                        f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer'
+                    assert (
+                        v.is_meta == should_be_meta
+                    ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"

         _check_kwargs(concrete_args, should_be_meta=False)
         _check_kwargs(meta_args, should_be_meta=True)
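The reformatted assertions enforce that `concrete_args` hold real tensors while `meta_args` hold meta tensors; the `is_meta` flag they test is standard PyTorch:

```python
import torch

concrete = torch.randn(2, 3)
meta = torch.empty(2, 3, device="meta")

print(concrete.is_meta)  # False: backed by real storage
print(meta.is_meta)      # True: carries only shape and dtype, no data
```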
@@ -442,7 +450,6 @@ class ColoTracer(Tracer):
         orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction

         class PatchedCheckpointFunction(torch.autograd.Function):
-
             @staticmethod
             def forward(ctx, run_function, preserve_rng_state, *args):
                 # signal that the current tracing occurs within activation checkpoint part
@@ -455,7 +462,8 @@ class ColoTracer(Tracer):
             @staticmethod
             def backward(ctx: Any, *grad_outputs: Any) -> Any:
                 raise NotImplementedError(
-                    "We do not implement the backward pass as we only trace the forward pass.")
+                    "We do not implement the backward pass as we only trace the forward pass."
+                )

         # override the checkpoint function
         torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
@@ -470,12 +478,11 @@ class ColoTracer(Tracer):

         if self.inside_torch_checkpoint_func:
             # annotate the activation checkpoint module
-            node.meta['activation_checkpoint'] = self.act_ckpt_region_count
+            node.meta["activation_checkpoint"] = self.act_ckpt_region_count
         return node


 def wrap_tensor_constructor_method(target):
-
     def look_for_proxy(*args, **kwargs):
         # find in pos vars
         for arg in args:
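For orientation, the region counter written into `node.meta["activation_checkpoint"]` covers the nodes created while tracing inside a `torch.utils.checkpoint.checkpoint` call, i.e. modules written like the following generic sketch (not ColossalAI-specific code):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.body = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

    def forward(self, x):
        # everything traced inside this call belongs to one checkpoint region
        return checkpoint(self.body, x)

out = Block()(torch.randn(2, 8, requires_grad=True))
print(out.shape)  # torch.Size([2, 8])
```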
@@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target):
 for method in magic_methods:

     def _scope(method):
-
         def impl(*args, **kwargs):
-
             tracer = args[0].tracer
             target = getattr(operator, method)
-            proxy = tracer.create_proxy('call_function', target, args, kwargs)
+            proxy = tracer.create_proxy("call_function", target, args, kwargs)
             if not isinstance(proxy, ColoProxy):
                 meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
                 proxy = ColoProxy(proxy.node)
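The loop above re-binds the arithmetic magic methods so that the resulting proxies are `ColoProxy` instances carrying meta data; the underlying dispatch is the same one plain `torch.fx` uses, where an operator applied to a proxy becomes a `call_function` node:

```python
import torch
import torch.fx

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1  # Proxy.__add__ -> call_function node targeting operator.add

gm = torch.fx.symbolic_trace(AddOne())
print([(n.op, n.target) for n in gm.graph.nodes if n.op == "call_function"])
# [('call_function', <built-in function add>)]
```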
@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):

     def impl(self, rhs):
         target = getattr(operator, orig_method_name)
-        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
         if not isinstance(proxy, ColoProxy):
             meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
             proxy = ColoProxy(proxy.node)