[shardformer] llama support DistCrossEntropy (#5176)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* fix ci

* fix ci

---------

Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
This commit is contained in:
flybird11111
2023-12-13 01:39:14 +08:00
committed by GitHub
parent cefdc32615
commit 79718fae04
5 changed files with 143 additions and 13 deletions

View File

@@ -17,23 +17,32 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(8, (2, 4))
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
labels = torch.randint(8, (2, 4)).cuda()
# set some label to -100 to test the ignore index
labels[0, -1] = ignore_index
org_pred = pred.view(-1, 8)
org_labels = labels.view(-1)
org_loss = F.cross_entropy(org_pred, org_labels)
pred.retain_grad()
org_loss.backward()
dist_pred = pred.chunk(world_size, -1)[rank]
dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index)
dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
dist_pred.requires_grad = True
dist_loss = cross_entropy_1d(dist_pred, labels, ignore_index=ignore_index)
dist_pred.retain_grad()
dist_loss.backward()
assert torch.allclose(
org_loss, dist_loss, atol=1e-5
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_crossentropy():