Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 20:54:35 +00:00
[fx] support module with bias addition (#1780)
* [autoparallel] refactor tracer to fix bias addition issue
* [fx] support module with bias addition
* create bias_addition_module
* refactor file structure
* polish code
* fix unit test
@@ -18,11 +18,10 @@ from torch.fx import Node, Tracer
 from torch.fx.graph import Graph, magic_methods, reflectable_magic_methods
 from torch.fx.proxy import ParameterProxy, Proxy
 
-from colossalai.fx.tracer.meta_patch import meta_patched_module
-
 from ..proxy import ColoProxy
 from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
-from .meta_patch import meta_patched_function, meta_patched_module
+from .bias_addition_patch import module_to_func_dict
+from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
 
 __all__ = ['ColoTracer']
 
@@ -79,18 +78,126 @@ class ColoTracer(Tracer):
         """
         Create a proxy for different kinds of operations.
         """
-        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
-
         if self.tracer_type == TracerType.DEFAULT:
             # since meta_args is not given
             # we just fall back to the original torch.fx.Tracer
+            proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
             return proxy
+
+        # if the graph is traced for an auto-parallel module, some extra nodes will be added during
+        # graph construction to deal with the compatibility between bias addition and all-reduce.
+
+        # if no extra manipulation is applied, we just pass the original arguments to create_proxy
+        # to create a node on the computation graph
+        origin_arguments = (kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+        # dispatch the bias-addition handler depending on the kind and target in the original arguments
+        args_metas, _ = extract_meta(*args, **kwargs)
+        if kind == "call_function":
+            if bias_addition_function.has(target):
+                return bias_addition_function.get(target)(self, target, args, kwargs)
+            elif bias_addition_function.has(target.__name__):
+                # use the name for some builtin ops like @ (matmul)
+                return bias_addition_function.get(target.__name__)(self, target, args, kwargs)
+
+        elif kind == "call_method":
+            method = getattr(args_metas[0].__class__, target)
+            if bias_addition_function.has(method):
+                return bias_addition_function.get(method)(self, target, args, kwargs)
+
+        elif kind == "call_module":
+            if not hasattr(self, "orig_forward"):
+                raise AttributeError(f"{self} does not have an attribute called orig_forward")
+            self._disable_module_getattr = True
+            try:
+                mod = self.root.get_submodule(target)
+                mod_type = type(mod)
+                if bias_addition_module.has(mod_type) and mod.bias is not None:
+                    function_to_substitute = module_to_func_dict[mod_type]
+                    handle = bias_addition_module.get(mod_type)(self, target, args, kwargs, function_to_substitute)
+                    return handle.generate()
+            finally:
+                self._disable_module_getattr = False
+
+        # create nodes using the original arguments
+        proxy = super().create_proxy(*origin_arguments)
+        proxy: ColoProxy
+        meta_out = self._meta_data_computing(
+            kind,
+            target,
+            args,
+            kwargs,
+        )
+        proxy.meta_data = meta_out
+
+        return proxy
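Note: the effect of the bias-addition substitution is easiest to see outside the tracer. The snippet below is a minimal, hypothetical sketch, not the actual bias_addition_module handle from this PR: a Linear with bias is rewritten as a bias-free F.linear followed by an explicit add, so the matmul and the bias addition become separate graph nodes and an auto-parallel pass can place an all-reduce between them.

import torch
import torch.nn.functional as F

def decomposed_linear(x, weight, bias):
    # bias-free matmul: the part whose output may need an all-reduce
    # when the weight is sharded for tensor parallelism
    out = F.linear(x, weight)
    # the bias addition becomes its own node, applied after the reduction
    return out + bias

x = torch.rand(4, 8)
linear = torch.nn.Linear(8, 16)
assert torch.allclose(decomposed_linear(x, linear.weight, linear.bias), linear(x), atol=1e-6)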
+
+    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+        if getattr(self, "_disable_module_getattr", False):
+            return attr_val
+        else:
+            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+                for n, p in collection_to_search:
+                    if attr_val is p:
+                        if n not in parameter_proxy_cache:
+                            kwargs = {}
+                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+                                kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
+                                                              lambda node: ParameterProxy(self, node, n, attr_val))
+                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)    # type: ignore[arg-type]
+                            parameter_proxy_cache[n] = val_proxy
+                        return parameter_proxy_cache[n]
+                return None
+
+            if isinstance(attr_val, torch.nn.Parameter):
+                maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
+                                                                 parameter_proxy_cache)
+                if maybe_parameter_proxy is not None:
+                    return maybe_parameter_proxy
+
+            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+                maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
+                                                              parameter_proxy_cache)
+                if maybe_buffer_proxy is not None:
+                    return maybe_buffer_proxy
+
+            return attr_val
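Note: for reference, a plain torch.fx example (no ColossalAI involved) of what _module_getattr governs: reading a parameter inside forward is recorded as a get_attr node, which is exactly the proxying that _disable_module_getattr switches off above so the tracer can inspect the real mod.bias value.

import torch
import torch.fx

class Scale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))

    def forward(self, x):
        # this attribute access is intercepted and becomes a get_attr node
        return x * self.weight

graph = torch.fx.Tracer().trace(Scale())
print([node.op for node in graph.nodes])
# ['placeholder', 'get_attr', 'call_function', 'output']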
+
+    def call_module(self, m, forward, args, kwargs):
+        self.orig_forward = forward
+        module_qualified_name = self.path_of_module(m)
+
+        # a leaf module is a torch.nn.Module subclass defined under `torch.nn`,
+        # which means customized modules are not leaf modules by default.
+        # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
+        # we should treat it as a leaf module as well
+        if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
+            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
+        else:
+            return forward(*args, **kwargs)
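Note: the leaf-module rule the comment describes comes from torch.fx itself. Roughly, paraphrasing the default Tracer.is_leaf_module (its exact form has varied slightly across torch versions):

import torch.nn as nn

def default_is_leaf(m: nn.Module) -> bool:
    # modules shipped under torch.nn are leaves, except containers,
    # which are always traced through
    return m.__module__.startswith("torch.nn") and not isinstance(m, nn.Sequential)

print(default_is_leaf(nn.Linear(4, 4)))    # True: recorded as one call_module node
print(default_is_leaf(nn.Sequential()))    # False: traced through

The meta_patched_module.has(m.__class__) check extends this set, which is how a patched third-party module such as apex.normalization.FusedRMSNorm is kept as a single call_module node instead of being inlined.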
+
+    def proxy(self, node) -> Proxy:
+        """
+        Returns a ColoProxy object.
+        """
+        return self.proxy_cls(node, self)
+
+    def _configure_tracer_type(self, tracer_type: TracerType):
+        if tracer_type == TracerType.DEFAULT:
+            self.proxy_cls = Proxy
+            self.tracer_type = TracerType.DEFAULT
+        elif tracer_type == TracerType.META:
+            self.proxy_cls = ColoProxy
+            self.tracer_type = TracerType.META
+        else:
+            raise ValueError(f"Unrecognised tracer type {tracer_type}")
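Note: a usage sketch tying the two tracer types together, under the assumptions that ColoTracer is importable from colossalai.fx and that trace() accepts the meta_args dict which the placeholder branch below reads. Passing meta tensors selects TracerType.META and ColoProxy; tracing without meta_args falls back to vanilla torch.fx behavior with plain Proxy objects.

import torch
from colossalai.fx import ColoTracer    # import path assumed

model = torch.nn.Linear(8, 16)
tracer = ColoTracer()
# meta tensors drive shape propagation; no real memory is allocated
graph = tracer.trace(model, meta_args={"input": torch.rand(8, device="meta")})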
+
+    def _meta_data_computing(self, kind, target, args, kwargs):
+
         if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
-            proxy.meta_data = self.meta_args[target]
-            return proxy
+            meta_out = self.meta_args[target]
+            return meta_out
 
         if target in self.orig_torch_tensor_methods:
             # NOTE: tensor constructors in PyTorch define the `device` argument as
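Note: the truncated NOTE concerns torch's meta device. Because tensor constructors accept device as a keyword argument, the tracer can rewrite kwargs to device="meta" and obtain shapes and dtypes without allocating any storage. A standalone illustration:

import torch

t = torch.empty(1024, 1024, device="meta")
print(t.shape, t.dtype, t.is_meta)    # torch.Size([1024, 1024]) torch.float32 True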
@@ -154,75 +261,12 @@ class ColoTracer(Tracer):
                 finally:
                     self._disable_module_getattr = False
             else:
                 return proxy
 
-            if not isinstance(proxy, Proxy):
-                raise ValueError("Don't support composite output yet")
-            proxy.meta_data = meta_out
         except Exception as e:
             raise RuntimeError(f"Could not compute metadata for {kind} target {target}: {e}")
-        return proxy
-
-    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
-        if getattr(self, "_disable_module_getattr", False):
-            return attr_val
-        else:
-            # return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
-            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
-                for n, p in collection_to_search:
-                    if attr_val is p:
-                        if n not in parameter_proxy_cache:
-                            kwargs = {}
-                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
-                                kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else
-                                                              lambda node: ParameterProxy(self, node, n, attr_val))
-                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)    # type: ignore[arg-type]
-                            parameter_proxy_cache[n] = val_proxy
-                        return parameter_proxy_cache[n]
-                return None
-
-            if isinstance(attr_val, torch.nn.Parameter):
-                maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
-                                                                 parameter_proxy_cache)
-                if maybe_parameter_proxy is not None:
-                    return maybe_parameter_proxy
-
-            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
-                maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(),
-                                                              parameter_proxy_cache)
-                if maybe_buffer_proxy is not None:
-                    return maybe_buffer_proxy
-
-            return attr_val
-
-    def call_module(self, m, forward, args, kwargs):
-        self.orig_forward = forward
-        module_qualified_name = self.path_of_module(m)
-
-        # a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
-        # which means customized modules are not leaf module by default
-        # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
-        # we should treat it as leaf module as well
-        if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
-            return self.create_proxy('call_module', module_qualified_name, args, kwargs)
-        else:
-            return forward(*args, **kwargs)
-
-    def proxy(self, node) -> Proxy:
-        """
-        Returns a ColoProxy object.
-        """
-        return self.proxy_cls(node, self)
-
-    def _configure_tracer_type(self, tracer_type: TracerType):
-        if tracer_type == TracerType.DEFAULT:
-            self.proxy_cls = Proxy
-            self.tracer_type = TracerType.DEFAULT
-        elif tracer_type == TracerType.META:
-            self.proxy_cls = ColoProxy
-            self.tracer_type = TracerType.META
-        else:
-            raise ValueError(f"Unrecognised tracer type {tracer_type}")
+        return meta_out
 
     def trace(self,
               root: nn.Module,