mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[shardformer] Pytree fix (#4533)
* pytree test * test bert * test bert * test bert * revise * add register * add register
This commit is contained in:
@@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy):
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
# the batch_size_dim is bounded to Model
|
||||
bsz_dim = 1
|
||||
setattr(self.model, 'batch_size_dim', bsz_dim)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
Reference in New Issue
Block a user