ColossalAI/colossalai/pipeline/weight_grad_store.py
duanjunwen a9bedc7a43
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;

* [fix] fix incorrect number of gradients ;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Han doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] ifx torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;

* [fix] fix policy use fused_linear

* [fix] fix weight grad none, err caused by  weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shadformer] fix doc test cmd;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 10:22:26 +08:00


import queue


class WeightGradStore:

    # Pending weight-gradient work recorded by put() since the last flush().
    cache = []
    # One FIFO queue of flushed batches per model chunk (chunk 0 and chunk 1).
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # Record what is needed to compute dW later instead of computing it now.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # Move all cached entries into the given chunk's queue as one batch.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # Run the oldest flushed batch of weight-gradient work for this chunk.
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if isinstance(weight, tuple):
                    # To be hooked into Gemini's '__torch_function__', a view operation is applied to weight and bias.
                    # The view changes the weight's data pointer, so both tensors are passed as a tuple:
                    # weight_cal is used to compute dW, weight_origin is the parameter whose grad is updated.
                    _, weight_origin = weight
                    if weight_origin.grad is not None:
                        # .grad already exists: func is expected to accumulate dW into it in place.
                        func(total_input, grad_output, weight_origin.grad)
                    else:
                        # First backward pass: .grad is None, so assign the computed dW.
                        grad_weight = func(total_input, grad_output)
                        weight_origin.grad = grad_weight
                else:
                    if weight.grad is not None:
                        # .grad already exists: func is expected to accumulate dW into it in place.
                        func(total_input, grad_output, weight.grad)
                    else:
                        # First backward pass: .grad is None, so assign the computed dW.
                        grad_weight = func(total_input, grad_output)
                        weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")
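The control flow of `WeightGradStore` can be exercised without PyTorch. The sketch below is a hypothetical, stand-alone mirror of the class: `ToyWeightGradStore`, `ToyParam` and `toy_wgrad` are illustrative names, not ColossalAI APIs, and lists of floats stand in for tensors. It assumes, as the three-argument call in `pop` suggests, that `func` accumulates into an existing gradient in place and returns a fresh one otherwise. It shows the per-chunk FIFO order of flushed batches, assign-then-accumulate on `.grad`, and the tuple case where the gradient lands on `weight_origin` rather than on the view.

```python
import queue


class ToyParam:
    """Hypothetical stand-in for a parameter tensor; .grad is a list of floats."""

    def __init__(self):
        self.grad = None


def toy_wgrad(total_input, grad_output, out=None):
    # Hypothetical dW for an elementwise scale y = w * x, i.e. dL/dw = grad_output * x.
    grad = [g * x for g, x in zip(grad_output, total_input)]
    if out is None:
        return grad  # 2-arg call: return a fresh gradient
    for i, v in enumerate(grad):
        out[i] += v  # 3-arg call: accumulate into out in place
    return out


class ToyWeightGradStore:
    """Hypothetical mirror of WeightGradStore's put/flush/pop flow."""

    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one FIFO per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() == 0:
            raise RuntimeError("Pop empty queue.")  # the original raises a bare Exception
        for total_input, grad_output, weight, func in cls.weight_grad_queue[chunk].get():
            # Both branches of the original pop() folded into one: for a
            # (weight_cal, weight_origin) tuple the gradient goes to weight_origin.
            target = weight[1] if isinstance(weight, tuple) else weight
            if target.grad is not None:
                func(total_input, grad_output, target.grad)
            else:
                target.grad = func(total_input, grad_output)


w0, w1 = ToyParam(), ToyParam()
w1_view = ToyParam()  # hypothetical stand-in for the view used in the computation

# Chunk 0, microbatch 1: record deferred dW work, then flush it as one batch.
ToyWeightGradStore.put([1.0, 2.0], [0.5, 0.5], w0, toy_wgrad)
ToyWeightGradStore.flush(chunk=0)
# Chunk 1: a tuple weight, mimicking the view-based path.
ToyWeightGradStore.put([3.0, 4.0], [1.0, 1.0], (w1_view, w1), toy_wgrad)
ToyWeightGradStore.flush(chunk=1)
# Chunk 0, microbatch 2.
ToyWeightGradStore.put([2.0, 2.0], [1.0, 0.0], w0, toy_wgrad)
ToyWeightGradStore.flush(chunk=0)

ToyWeightGradStore.pop(chunk=0)  # first batch for w0: assigns the grad
ToyWeightGradStore.pop(chunk=1)  # grad goes to w1, not w1_view
ToyWeightGradStore.pop(chunk=0)  # second batch for w0: accumulates in place
print(w0.grad, w1.grad, w1_view.grad)  # prints [2.5, 1.0] [3.0, 4.0] None
```

One design point the sketch makes visible: because `cache` and `weight_grad_queue` are class attributes, the store behaves like a process-wide singleton, so every layer that calls `put` shares the same cache, and the order of `flush` calls per chunk determines the order in which `pop` replays the weight-gradient work.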