Commit Graph

320 Commits

Author SHA1 Message Date
ver217
a93a7d7364 [hotfix] fix reuse_fp16_shard of sharded model (#756)
* fix reuse_fp16_shard

* disable stm test

* polish code
2022-04-14 14:56:46 +08:00
ver217
8f7ce94b8e [hotfix] fix auto tensor placement policy (#753) 2022-04-14 12:04:45 +08:00
HELSON
84c6700b2a [zero] refactor memstats_collector (#746) 2022-04-14 12:01:12 +08:00
Jiarui Fang
3d7dc46d33 [zero] use factory pattern for tensor_placement_policy (#752) 2022-04-14 11:07:29 +08:00
ver217
4b048a8728 fix prepare grads in sharded optim (#749) 2022-04-13 22:36:11 +08:00
ver217
e396bb71f2 [zero] add tensor placement policies (#743)
* add tensor placement policies

* polish comments

* polish comments

* update moe unit tests
2022-04-13 15:00:48 +08:00
HELSON
22c4b88d56 [zero] refactor ShardedParamV2 for convenience (#742) 2022-04-13 14:54:26 +08:00
ver217
e6212f56cd [hotfix] fix memory leak in backward of sharded model (#741) 2022-04-13 09:59:05 +08:00
Jiarui Fang
7db3ccc79b [hotfix] remove duplicated param register to stateful tensor manager (#728) 2022-04-12 13:55:25 +08:00
Jiarui Fang
4d90a7b513 [refactor] zero directory (#724) 2022-04-11 23:13:02 +08:00
Jiarui Fang
193dc8dacb [refactor] refactor the memory utils (#715) 2022-04-11 16:47:57 +08:00
HELSON
dbd96fe90a [zero] check whether gradients have inf and nan in gpu (#712) 2022-04-11 15:40:13 +08:00
ver217
715b86eadd [hotfix] fix stm cuda model data size (#710) 2022-04-11 15:10:39 +08:00
HELSON
a9b8300d54 [zero] improve adaptability for non-sharded parameters (#708)
* adapt post-grad hooks for non-sharded parameters

* adapt optimizer for non-sharded parameters

* offload gradients for non-replicated parameters
2022-04-11 13:38:51 +08:00
ver217
ab8c6b4a0e [zero] refactor memstats collector (#706)
* refactor memstats collector

* fix disposable

* polish code
2022-04-11 10:46:08 +08:00
HELSON
ee112fe1da [zero] adapt zero hooks for unsharded module (#699) 2022-04-08 20:23:26 +08:00
ver217
3c9cd5bb5e [zero] stateful tensor manager (#687)
* [WIP] stateful tensor manager

* add eviction strategy

* polish code

* polish code

* polish comment

* add unit test

* fix sampler bug

* polish code

* fix max sampling count resetting bug

* fix sampler bug

* polish code

* fix bug

* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
2022-04-08 17:51:34 +08:00
HELSON
d7ecaf362b [zero] fix init bugs in zero context (#686)
* adapt model weight initialization for methods in PyTorch nn.init
2022-04-07 17:38:45 +08:00
Jiarui Fang
59bf2dc590 [zero] initialize a stateful tensor manager (#614) 2022-04-06 16:18:49 +08:00
HELSON
17e73e62cc [hotfix] fix bugs for unsharded parameters when restoring data (#664) 2022-04-03 22:02:11 +08:00
Jiarui Fang
0aab52301e [hotfix] fix a bug in model data stats tracing (#655) 2022-04-03 21:48:06 +08:00
Jiarui Fang
036404ca8a Revert "[zero] polish init context (#645)" (#657) 2022-04-02 18:30:06 +08:00
Jiarui Fang
67b4928244 [zero] polish init context (#645) 2022-04-02 15:52:04 +08:00
HELSON
055fbf5be6 [zero] adapt zero for unsharded parameters (Optimizer part) (#601) 2022-04-01 20:10:47 +08:00
ver217
0ef8819c67 polish docstring of zero (#612) 2022-04-01 14:50:56 +08:00
ver217
9bee119104 [hotfix] fix sharded optim zero grad (#604)
* fix sharded optim zero grad

* polish comments
2022-04-01 12:41:20 +08:00
Jiarui Fang
e956d93ac2 [refactor] memory utils (#577) 2022-04-01 09:22:33 +08:00
HELSON
e6d50ec107 [zero] adapt zero for unsharded parameters (#561)
* support existing sharded and unsharded parameters in zero

* add unit test for moe-zero model init

* polish moe gradient handler
2022-03-31 18:34:11 +08:00
ver217
7c6c427db1 [zero] trace states of fp16/32 grad and fp32 param (#571) 2022-03-31 16:26:54 +08:00
Jiarui Fang
7675366fce [polish] rename col_attr -> colo_attr (#558) 2022-03-31 12:25:45 +08:00
ver217
014bac0c49 [zero] hijack p.grad in sharded model (#554)
* hijack p.grad in sharded model

* polish comments

* polish comments
2022-03-30 18:14:50 +08:00
Jiarui Fang
f552b11294 [zero] label state for param fp16 and grad (#551) 2022-03-30 15:57:46 +08:00
Jiarui Fang
214da761d4 [zero] add stateful tensor (#549) 2022-03-30 13:51:37 +08:00
Jiarui Fang
107b99ddb1 [zero] dump memory stats for sharded model (#548) 2022-03-30 09:38:44 +08:00
HELSON
8c90d4df54 [zero] add zero context manager to change config during initialization (#546) 2022-03-29 17:57:59 +08:00
Jiarui Fang
53b1b6e340 [zero] non model data tracing (#545) 2022-03-29 15:45:48 +08:00
ver217
fb841dd5c5 [zero] optimize grad offload (#539)
* optimize grad offload

* polish code

* polish code
2022-03-29 12:48:00 +08:00
ver217
1f90a3b129 [zero] polish ZeroInitContext (#540) 2022-03-29 09:09:04 +08:00
Jiarui Fang
c11ff81b15 [zero] get memory usage of sharded optim v2. (#542) 2022-03-29 09:08:18 +08:00
HELSON
a30e2b4c24 [zero] adapt for non-leaf modules in zero (#535)
only process a module's own parameters in the Zero context

add zero hooks for all modules that contain parameters

gather only the parameters belonging to the module itself
2022-03-28 17:42:18 +08:00
Jiarui Fang
705f56107c [zero] refactor model data tracing (#537) 2022-03-28 16:38:18 +08:00
Jiarui Fang
a590ed0ba3 [zero] improve the accuracy of get_memory_usage of sharded param (#538) 2022-03-28 16:19:19 +08:00
Jiarui Fang
37cb70feec [zero] get memory usage for sharded param (#536) 2022-03-28 15:01:21 +08:00
Jiarui Fang
05e33b2578 [zero] fix grad offload (#528)
* [zero] fix grad offload

* polish code
2022-03-25 18:23:25 +08:00
Jiarui Fang
8d8c5407c0 [zero] refactor model data tracing (#522) 2022-03-25 18:03:32 +08:00
Jiarui Fang
4d322b79da [refactor] remove old zero code (#517) 2022-03-25 14:54:39 +08:00
Jiarui Fang
920c5889a7 [zero] add colo move inline (#521) 2022-03-25 14:02:55 +08:00
Jiarui Fang
0bebda6ea5 [zero] fix init device bug in zero init context unittest (#516) 2022-03-25 12:24:18 +08:00
Jiarui Fang
7ef3507ace [zero] show model data cuda memory usage after zero context init. (#515) 2022-03-25 11:23:35 +08:00
ver217
a2e61d61d4 [zero] zero init ctx enable rm_torch_payload_on_the_fly (#512)
* enable rm_torch_payload_on_the_fly

* polish docstr
2022-03-24 23:44:00 +08:00