mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23 14:10:29 +00:00)
add DAPO support
commit c0c0da2f26
parent 9474316132
@@ -112,7 +112,7 @@ class BaseConsumer:
                 self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
                 batch = bind_batch(batches)
                 batch = post_recv(batch)
-                loss = self.step(i, **batch)
+                loss = self.step(i, pbar, **batch)
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                 i += 1
@@ -181,7 +181,7 @@ class SimpleConsumer(BaseConsumer):
         super().setup()
         self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         labels = kwargs["input_ids"].clone()
         labels[kwargs["attention_mask"] == 0] = -100
         kwargs["labels"] = labels
@@ -1,7 +1,7 @@
 import json
 import os
 from contextlib import nullcontext
-from typing import Optional
+from typing import Optional, Any
 
 import ray
 import torch
@@ -144,7 +144,7 @@ class GRPOConsumer(BaseConsumer):
         self.reference_model, *_ = self.booster.boost(self.reference_model)
         self.plugin.logger.set_level("ERROR")
 
-    def step(self, step_idx: int, **kwargs) -> Optional[float]:
+    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
         """
         Step data from policy model:
         [{
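
Note: the four hunks above make one signature change in two files. BaseConsumer.loop now passes its tqdm progress bar into step(), and both SimpleConsumer.step and GRPOConsumer.step accept it as pbar: Any (hence the new typing import). A minimal, self-contained sketch of the pattern, with a hypothetical ToyConsumer standing in for the real classes:

# Illustrative sketch (simplified, not the repository's classes): the commit
# threads the tqdm progress bar from the consumer's training loop into step(),
# so subclasses can report sample-collection status on the same bar.
from typing import Any, Optional

from tqdm import tqdm


class ToyConsumer:
    def loop(self, num_steps: int) -> None:
        with tqdm(range(num_steps), desc="step") as pbar:
            for i in pbar:
                loss = self.step(i, pbar, value=float(i))  # pbar now passed through
                if loss is not None:
                    pbar.set_postfix({"loss": loss})

    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        # A subclass can surface intermediate status before a loss is available.
        pbar.set_postfix({"Status": f"collecting for step {step_idx + 1}"})
        return kwargs["value"] * 0.5


if __name__ == "__main__":
    ToyConsumer().loop(5)
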
@@ -212,6 +212,7 @@ class GRPOConsumer(BaseConsumer):
         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
         need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
+        pbar.set_postfix({"Step": self.global_step + 1, "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.75}"})
 
         # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
         ctx = (
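
Note: the new pbar.set_postfix call reports how many valid samples have been collected toward the update threshold. A small worked example of what the status string renders, with assumed values (batch_size=4, dp_size=8, num_generations=4 are illustrative, not taken from this diff):

# Worked example (assumed values): the status line shown while valid samples
# are being accumulated toward the gradient-update threshold.
batch_size, dp_size, num_generations = 4, 8, 4
effective_sample_count = 75
threshold = batch_size * dp_size * num_generations * 0.75  # 96.0 (a float)
status = f"Collecting: {effective_sample_count}/{threshold}"
print(status)                                   # Collecting: 75/96.0
need_update = effective_sample_count >= threshold  # False until >= 96 valid samples
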
@@ -409,6 +410,7 @@ class GRPOConsumer(BaseConsumer):
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
+            self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_sample_count = 0
             self.total_sample_count = 0
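
Note: with the added line, global_step advances only when need_update fired and a real optimizer step ran; the effective/total sample counters feed the utilization statistic and are then reset. A simplified sketch of that accumulate / update / reset cycle, with a hypothetical ToyTrainer in place of GRPOConsumer:

# Simplified sketch (hypothetical class, not the repository's GRPOConsumer):
# the step counter only advances on an actual optimizer step, and the sample
# counters are reset after computing utilization.
from typing import Optional


class ToyTrainer:
    def __init__(self, threshold: int) -> None:
        self.threshold = threshold           # minimum valid samples per update
        self.global_step = 0
        self.effective_sample_count = 0
        self.total_sample_count = 0

    def accumulate(self, num_valid: int, num_total: int) -> Optional[float]:
        self.effective_sample_count += num_valid
        self.total_sample_count += num_total
        if self.effective_sample_count < self.threshold:
            return None                      # keep collecting, no update yet
        # optimizer.step() and optimizer.zero_grad() would run here
        self.global_step += 1                # advances only on a real update
        sample_utilization = self.effective_sample_count / self.total_sample_count
        self.effective_sample_count = 0
        self.total_sample_count = 0
        return sample_utilization


trainer = ToyTrainer(threshold=96)
print(trainer.accumulate(num_valid=50, num_total=64))  # None, still collecting
print(trainer.accumulate(num_valid=50, num_total=64))  # 0.78125, update fired
print(trainer.global_step)                             # 1
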
@@ -83,7 +83,7 @@ if __name__ == "__main__":
     inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
     generate_config.update(
         dict(
-            max_tokens=2048,
+            max_tokens=4096,
             ignore_eos=True,
             include_stop_str_in_output=True,
             stop=["</answer>"],
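
Note: generate_config configures the rollout backend (the gpu_memory_utilization / enforce_eager / enable_chunked_prefill options suggest vLLM); this commit raises max_tokens from 2048 to 4096 while keeping generation anchored on the </answer> stop string. A hedged sketch of equivalent vLLM sampling parameters, assuming that backend; the model name below is a placeholder, not taken from this diff:

# Sketch only, assuming the vLLM backend implied by the options above; the
# training script's actual plumbing is not shown in this hunk.
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(
    max_tokens=4096,                  # raised from 2048 in this commit
    ignore_eos=True,
    include_stop_str_in_output=True,
    stop=["</answer>"],               # stop once the answer tag is closed
)

# llm = LLM(
#     model="<placeholder-model>",    # placeholder, not specified in this diff
#     gpu_memory_utilization=0.7,
#     enforce_eager=True,
#     enable_chunked_prefill=True,
# )
# outputs = llm.generate(prompts, sampling_params)
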
@@ -120,8 +120,8 @@ if __name__ == "__main__":
             "beta": 0.0, # no KL penalty
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": 1024 * 2,
-            "cache_length": 256,
+            "max_length": 4096,
+            "cache_length": 512,
             "filter_truncated_response": True,
         }
 
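Note: the training arguments above enable soft_over_length_punishment with max_length and cache_length both doubled (2048 to 4096 and 256 to 512). In DAPO-style training this means a response is only softly penalized once it enters the last cache_length tokens before max_length, rather than being scored as if length did not matter. A minimal sketch of such a penalty under those assumptions (illustrative only; the repository's reward code is not part of this diff):

# Minimal sketch of a soft over-length penalty consistent with the config above
# (max_length=4096, cache_length=512). This illustrates the idea, not the
# repository's actual reward implementation.
def soft_over_length_penalty(response_len: int, max_length: int = 4096, cache_length: int = 512) -> float:
    """Return a non-positive penalty added to the reward.

    Responses shorter than max_length - cache_length are not penalized; within
    the last cache_length tokens the penalty grows linearly to -1; responses at
    or beyond max_length receive the full -1 penalty.
    """
    soft_start = max_length - cache_length
    if response_len <= soft_start:
        return 0.0
    if response_len >= max_length:
        return -1.0
    return -(response_len - soft_start) / cache_length


# Examples: no penalty at 3584 tokens, half penalty at 3840, full at 4096.
assert soft_over_length_penalty(3584) == 0.0
assert soft_over_length_penalty(3840) == -0.5
assert soft_over_length_penalty(4096) == -1.0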