[pipeline] update shardformer policy

This commit is contained in:
ver217
2023-07-05 14:16:55 +08:00
committed by Hongxin Liu
parent 90a65ea682
commit 59f6f573f1
5 changed files with 84 additions and 8 deletions

View File

@@ -0,0 +1,19 @@
from typing import Set
import torch.nn as nn
def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
"""Set all parameters and buffers of model to None
Args:
model (nn.Module): The model to set
"""
if model in exclude:
return
for child in model.children():
set_tensors_to_none(child, exclude=exclude)
for n, p in model.named_parameters(recurse=False):
setattr(model, n, None)
for n, buf in model.named_buffers(recurse=False):
setattr(model, n, None)