[shardformer] fix flash attention: when the mask is causal, just don't unpad it (#5084)

* fix flash attn

* fix

flybird11111 committed (via GitHub) on 2023-11-22 16:00:07 +08:00
parent 75af66cd81
commit aae496631c
6 changed files with 16 additions and 8 deletions
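
The gist of the fix, as a minimal sketch assuming the flash-attn package is installed and a CUDA device is available; the function name, signature, and padding_mask argument are illustrative and not the actual shardformer code. The idea: when the mask is purely causal there is nothing to unpad, so the dense kernel is called directly with causal=True instead of taking the unpad/varlen/repad path.

    from flash_attn import flash_attn_func  # assumption: flash-attn 2.x is installed

    def dispatch_flash_attention(q, k, v, padding_mask=None, causal=False):
        # Illustrative dispatch only, not the ColossalAI implementation.
        # q, k, v: (batch, seqlen, n_heads, head_dim) fp16/bf16 CUDA tensors.
        if padding_mask is None:
            # Causal-only (or no) mask: there is no padding to strip, so skip the
            # unpad/repad round trip and let the dense kernel apply the causal mask.
            return flash_attn_func(q, k, v, causal=causal)
        # Padded batch: pack the valid tokens, run the varlen kernel, scatter back.
        raise NotImplementedError("unpad -> flash_attn_varlen_func -> pad path omitted")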


@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
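
For context, a self-contained usage sketch of the all_reduce_mean helper touched above, assuming the script is launched with torchrun; the gloo backend and the toy per-rank loss are purely for illustration.

    import torch
    import torch.distributed as dist

    def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
        # Sum the value across all ranks, then divide by the world size so that
        # every rank ends up holding the global mean.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor = tensor.data  # .data gives a detached view, keeping the in-place div outside autograd
        tensor.div_(dist.get_world_size())
        return tensor

    if __name__ == "__main__":
        # Hypothetical launch: torchrun --nproc_per_node=2 this_script.py
        dist.init_process_group(backend="gloo")
        loss = torch.tensor(float(dist.get_rank()))  # toy per-rank "loss"
        print(f"rank {dist.get_rank()}: mean = {all_reduce_mean(loss).item()}")
        dist.destroy_process_group()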