Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[feat] support zbv in mixtral benchmark; (#6083)
* [feat] support zbv in mixtral benchmark;
* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;
* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;
* [feat] support MixtralPipelineForwards --> mixtral_for_causal_lm_forward for zbv
* [fix] fix llama, mixtral benchmark zbv loss-none bug; update mixtral & llama policy and modeling;
* [feat] Linear1D_Col/Row support zbv WeightGradStore; (a minimal sketch of the deferred weight-gradient idea follows this list)
* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;
* [fix] fix test case; moe error in second iter
* [feat] EPMixtralSparseMoeBlock (op in MoE) support zbv;
* [fix] fix bwd b; now bwd w only for layers replaced by Linear1D_Col/Row; other layers perform a full bwd;
* [fix] debug zbv llama test;
* [fix] rm use_zbv flag in ShardConfig; rm debug info;
* [fix] add & fix llama test
* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);
* [fix] fix fail case test_shard_llama
* [fix] fix test_shard_llama
* [fix] fix llama modeling policy;
* [fix] fix test_shard_llama ci;
* [fix] fix test zerobubble
* [fix] fix handle name; rm useless comments;
* [fix] fix send recv signature;
* [fix] fix comment in llama & benchmark
* [feat] support no-tensor-parallel Linear in shardformer; add tests for using and not using WeightGradStore
* [fix] fix linear (no tp) ops func name;
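The Linear1D_Col/Row + WeightGradStore bullet is the core of the zbv (zero-bubble-V) change: a linear layer's backward is split into an immediate input-gradient step (the "B" pass) and a deferred weight-gradient step (the "W" pass) that is queued into the WeightGradStore class added in the new file below. The following is only a minimal sketch of that idea, assuming the module path introduced by this commit; DeferredLinearFunction and execute_w_pass are illustrative names, not the actual Linear1D_Col/Row implementation.

import torch

from colossalai.pipeline.weight_grad_store import WeightGradStore  # new file in this commit


def execute_w_pass(total_input, grad_output, grad_accum=None):
    # Illustrative weight-gradient kernel, matching the two call signatures used by
    # WeightGradStore.pop(): the 3-arg form accumulates in place, the 2-arg form
    # returns a fresh gradient tensor.
    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t() @ total_input.reshape(
        -1, total_input.shape[-1]
    )
    if grad_accum is not None:
        grad_accum.add_(grad_weight)
        return None
    return grad_weight


class DeferredLinearFunction(torch.autograd.Function):
    # Illustrative linear op whose weight gradient is deferred to the W pass.

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = grad_output @ weight  # input gradient computed right away (B pass)
        # The weight gradient is NOT computed here; it is queued and executed later (W pass).
        WeightGradStore.put(x, grad_output, weight, execute_w_pass)
        return grad_input, None

With this split, weight.grad stays empty after backward() until the pipeline schedule later calls WeightGradStore.flush(chunk) and WeightGradStore.pop(chunk), which is what allows a zero-bubble schedule to fill pipeline bubbles with deferred W passes.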
colossalai/pipeline/weight_grad_store.py (new file, 32 lines added)
@@ -0,0 +1,32 @@
import queue


class WeightGradStore:

    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # func(total_input, grad_output, weight.main_grad)
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if weight.grad is not None:
                    func(total_input, grad_output, weight.grad)
                # for first bwd; weight.grad is None, assign grad_weight to weight.grad
                else:
                    grad_weight = func(total_input, grad_output)
                    weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")
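A hedged usage sketch of the class above, showing the put -> flush -> pop lifecycle end to end. wgrad_gemm is an illustrative kernel, not a function from the repository, but its two call signatures mirror the ones pop() expects.

import torch

from colossalai.pipeline.weight_grad_store import WeightGradStore


def wgrad_gemm(total_input, grad_output, grad_accum=None):
    # 3-arg call accumulates into an existing weight.grad; 2-arg call returns a fresh grad.
    grad_weight = grad_output.t() @ total_input
    if grad_accum is not None:
        grad_accum.add_(grad_weight)
        return None
    return grad_weight


weight = torch.nn.Parameter(torch.zeros(4, 8))
x = torch.randn(2, 8)   # saved activation (total_input)
g = torch.randn(2, 4)   # upstream gradient (grad_output)

WeightGradStore.put(x, g, weight, wgrad_gemm)   # B pass: defer the weight gradient
WeightGradStore.flush(chunk=0)                  # end of this microbatch's backward
WeightGradStore.pop(chunk=0)                    # W pass: deferred kernels run now
assert weight.grad is not None and weight.grad.shape == weight.shape

Calling pop(chunk=0) again without a matching flush() would hit the empty-queue branch above and raise "Pop empty queue.", so the schedule must keep flush/pop calls paired per chunk.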