ColossalAI/colossalai/inference/pipeline/utils.py
Bin Jia 08a9f76b2f
[Pipeline Inference] Sync pipeline inference branch to main (#4820)
* [pipeline inference] pipeline inference (#4492)

* add pp stage manager as circle stage

* fix a bug when create process group

* add ppinfer basic framework

* add micro batch manager and support kvcache-pp gpt2 fwd

* add generate schedule

* use mb size to control mb number

* support generate with kv cache

* add output, remove unused code

* add test

* reuse shardformer to build model

* refactor some code and use the same attribute name of hf

* fix review and add test for generation

* remove unused file

* fix CI

* add cache clear

* fix code error

* fix typo

* [Pipeline inference] Modify to tieweight (#4599)

* modify the way of saving newtokens

* modify to tieweight

* modify test

* remove unused file

* solve review

* add docstring

* [Pipeline inference] support llama pipeline inference (#4647)

* support llama pipeline inference

* remove tie weight operation

* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)

* add benchmark verbose

* fix export tokens

* fix benchmark verbose

* add P2POp style to do p2p communication

* modify schedule as p2p type when ppsize is 2

* remove unused code and add docstring

* [Pipeline inference] Refactor code, add docsting, fix bug (#4790)

* add benchmark script

* update argparse

* fix fp16 load

* refactor code style

* add docstring

* polish code

* fix test bug

* [Pipeline inference] Add pipeline inference docs (#4817)

* add readme doc

* add a ico

* Add performance

* update table of contents

* refactor code (#4873)
2023-10-11 11:40:06 +08:00

36 lines
1.1 KiB
Python

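from typing import Optional, Set

import torch.nn as nn

from colossalai.shardformer._utils import getattr_, setattr_


def set_tensors_to_none(model: nn.Module, include: Optional[Set[str]] = None) -> None:
    """
    Set the parameters and buffers of the selected submodules to None,
    then set each of those submodules to None as well.

    Args:
        model (nn.Module): The model to modify in place
        include (Optional[Set[str]]): Names of the submodules to clear,
            relative to `model`. Nothing is changed when it is empty or None.
    """
    # Avoid a mutable default argument; None means "no modules selected".
    for module_suffix in include or set():
        set_module = getattr_(model, module_suffix)
        # Drop every parameter registered under this submodule.
        for n, _ in set_module.named_parameters():
            setattr_(set_module, n, None)
        # Drop every buffer registered under this submodule.
        for n, _ in set_module.named_buffers():
            setattr_(set_module, n, None)
        # Finally drop the submodule itself from the model.
        setattr_(model, module_suffix, None)


def get_suffix_name(suffix: str, name: str) -> str:
    """
    Build the qualified name of a module from its parent's name: `suffix[name]`
    when `name` is a digit string, `suffix.name` otherwise, and just `name`
    when `suffix` is empty.

    Args:
        suffix (str): The qualified name of the parent module
        name (str): The name of the current module

    Returns:
        str: The qualified name of the current module
    """
    # Compare by value: `is ''` tests object identity and raises a SyntaxWarning.
    point = '' if suffix == '' else '.'
    suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
    return suffix_name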
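
As a rough illustration of how `get_suffix_name` composes names, the standalone sketch below repeats the function, using a value comparison for the empty-suffix check, and applies it to a chain of child names. The names `transformer`, `h`, `0` and `attn` are hypothetical. They are chosen to resemble a GPT-2 style module tree and do not come from this file.

```python
def get_suffix_name(suffix: str, name: str) -> str:
    # Same logic as the helper above, repeated here so the sketch runs on its own.
    point = "" if suffix == "" else "."
    return suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"


# Hypothetical child names, visited from the root of a model downwards.
child_names = ["transformer", "h", "0", "attn"]

path = ""
paths = []
for child in child_names:
    # Each step extends the parent's qualified name with the child's name.
    path = get_suffix_name(path, child)
    paths.append(path)
    print(path)

# Prints:
# transformer
# transformer.h
# transformer.h[0]
# transformer.h[0].attn
```

Digit names are rendered as an index (`h[0]`) rather than an attribute (`h.0`), so the result reads like the Python expression that would reach the module. When `suffix` is empty and `name` is a digit string, the result is just `[name]`.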