mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[pipeline] update shardformer policy
This commit is contained in:
19
colossalai/shardformer/shard/utils.py
Normal file
19
colossalai/shardformer/shard/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Set
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
|
||||
"""Set all parameters and buffers of model to None
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to set
|
||||
"""
|
||||
if model in exclude:
|
||||
return
|
||||
for child in model.children():
|
||||
set_tensors_to_none(child, exclude=exclude)
|
||||
for n, p in model.named_parameters(recurse=False):
|
||||
setattr(model, n, None)
|
||||
for n, buf in model.named_buffers(recurse=False):
|
||||
setattr(model, n, None)
|
Reference in New Issue
Block a user