mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[shardformer] llama support DistCrossEntropy (#5176)
* fix aaa fix fix fix * fix * fix * test ci * fix ci fix * llama support dist-cross fix fix fix fix fix fix fix fix * fix * fix * fix fix * test ci * test ci * fix * [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878) * Add finetuning Colossal-Llama-2 example * Add finetuning Colossal-Llama-2 example 2 * Add finetuning Colossal-Llama-2 example and support NEFTuning * Add inference example and refine neftune * Modify readme file * update the imports --------- Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com> * llama support dist-cross fix fix fix fix fix fix fix fix * fix * fix * fix fix * test ci * test ci * fix * fix ci * fix ci --------- Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
This commit is contained in:
@@ -17,23 +17,32 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
||||
|
||||
# prepare data
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True)
|
||||
labels = torch.randint(8, (2, 4))
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
|
||||
labels = torch.randint(8, (2, 4)).cuda()
|
||||
# set some label to -100 to test the ignore index
|
||||
labels[0, -1] = ignore_index
|
||||
|
||||
org_pred = pred.view(-1, 8)
|
||||
org_labels = labels.view(-1)
|
||||
org_loss = F.cross_entropy(org_pred, org_labels)
|
||||
pred.retain_grad()
|
||||
org_loss.backward()
|
||||
|
||||
dist_pred = pred.chunk(world_size, -1)[rank]
|
||||
dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index)
|
||||
dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
|
||||
dist_pred.requires_grad = True
|
||||
dist_loss = cross_entropy_1d(dist_pred, labels, ignore_index=ignore_index)
|
||||
dist_pred.retain_grad()
|
||||
dist_loss.backward()
|
||||
|
||||
assert torch.allclose(
|
||||
org_loss, dist_loss, atol=1e-5
|
||||
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
|
||||
|
||||
|
||||
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
|
||||
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_dist_crossentropy():
|
||||
|
Reference in New Issue
Block a user