fix conversation; support sleep mode

YeAnbang 2025-04-23 14:43:54 +08:00
parent 39a1faac87
commit ab55b9f033
6 changed files with 29 additions and 14 deletions

View File

@@ -158,6 +158,7 @@ class SimpleConsumer(BaseConsumer):
model_config,
plugin_config,
microbatch_size=1,
+ save_dir="./model",
):
super().__init__(
num_producers,

View File

@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
generate_config=None,
grpo_config={},
project_name=None,
+ save_dir="./model",
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -63,6 +64,7 @@ class GRPOConsumer(BaseConsumer):
model_config,
plugin_config,
microbatch_size,
+ save_dir=save_dir,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -173,7 +175,7 @@ class GRPOConsumer(BaseConsumer):
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
- forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
+ train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward_group = self.reward_model(
data["input_ids"],
@@ -222,11 +224,11 @@ class GRPOConsumer(BaseConsumer):
# update gradients only once the full batch_size * dp_size * num_generations effective samples have been collected, since some samples are invalid and get filtered out.
# balance between efficiency and accuracy
- need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
+ need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
pbar.set_postfix(
{
"Step": self.global_step + 1,
- "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
+ "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
}
)
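The gate above accumulates effective (non-filtered) samples across rollout batches and only triggers a parameter update once a full effective batch of batch_size * dp_size * num_generations samples has been collected. A minimal, self-contained sketch of that accumulate-then-update pattern (function and variable names here are invented for illustration, not taken from the repository):

import torch

def step_if_ready(effective_sample_count, target, losses, optimizer):
    # Step the optimizer only once `target` effective samples have been accumulated;
    # otherwise keep collecting rollouts and return the unchanged counter.
    if effective_sample_count < target:
        return effective_sample_count
    loss = torch.stack(losses).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.clear()
    return 0  # reset the counter for the next effective batch

# `target` would be batch_size * dp_size * num_generations, matching the need_update check above.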
@@ -237,23 +239,23 @@ class GRPOConsumer(BaseConsumer):
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
- for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+ for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
input_ids_forward_micro_batch = data["input_ids"][
- forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
- forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
action_mask_forward_micro_batch = action_mask[
- forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
loss_mask_forward_micro_batch = (
- loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+ loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
- forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
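The renamed train_microbatch_size is used above to walk the rollout batch in fixed-size slices so each forward/backward pass stays within the configured micro-batch. A standalone sketch of that slicing pattern (tensor names are illustrative):

import torch

def iter_microbatches(batch, microbatch_size):
    # Yield dicts of aligned row slices of at most `microbatch_size` sequences.
    num_rows = batch["input_ids"].size(0)
    for start in range(0, num_rows, microbatch_size):
        yield {key: value[start : start + microbatch_size] for key, value in batch.items()}

batch = {
    "input_ids": torch.zeros(8, 16, dtype=torch.long),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
}
for microbatch in iter_microbatches(batch, microbatch_size=3):
    print(microbatch["input_ids"].shape)  # torch.Size([3, 16]) twice, then torch.Size([2, 16])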
if self.plugin.pp_size > 1:
@@ -442,7 +444,7 @@ class GRPOConsumer(BaseConsumer):
[
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
- f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
+ f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",

View File

@@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
class VLLMInferenceBackend(BaseInferenceBackend):
DEFAULT_MODEL_CONFIG = dict(
trust_remote_code=True,
+ enable_sleep_mode=False,
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
@@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
+ self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
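With enable_sleep_mode defaulting to False and the raw model_config now stored on the backend, the producer can later check whether sleep mode was requested. For reference, a minimal sketch of constructing a vLLM engine with sleep mode turned on (assumes a vLLM version that supports the enable_sleep_mode argument; the model path is a placeholder):

from vllm import LLM, SamplingParams

# Sleep mode lets the engine temporarily release GPU memory between generation rounds.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", enable_sleep_mode=True, gpu_memory_utilization=0.7)
sampling_params = SamplingParams(n=4, temperature=1.0, top_p=1.0, max_tokens=4096, logprobs=0)
outputs = llm.generate(["1 + 1 ="], sampling_params)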

View File

@@ -107,6 +107,7 @@ def launch_distributed(
grpo_config=grpo_config,
num_generations=num_generations,
project_name=project_name,
+ save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
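save_dir is resolved from grpo_config with a project-specific fallback and handed to the consumer. The diff does not show how the consumer uses it; one plausible, purely hypothetical use is saving sharded checkpoints under that directory via ColossalAI's Booster API:

import os

def save_checkpoint(booster, policy_model, save_dir, global_step):
    # Hypothetical helper (not from this commit): write a sharded checkpoint under save_dir.
    checkpoint_dir = os.path.join(save_dir, f"step_{global_step}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    booster.save_model(policy_model, checkpoint_dir, shard=True)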

View File

@@ -113,6 +113,10 @@ class BaseProducer:
if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
):
+ if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+ "enable_sleep_mode", False
+ ):
+ self.model.llm.sleep()  # evict the KV cache to avoid OOM
# don't sync model for last iteration
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
@@ -125,6 +129,10 @@ class BaseProducer:
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
+ if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+ "enable_sleep_mode", False
+ ):
+ self.model.llm.wake_up()
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
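Taken together, the two hunks above put the vLLM engine to sleep before new policy weights are pulled in and wake it up once loading finishes, so the rollout engine's GPU memory is freed while the sync is in flight. A minimal standalone sketch of that pattern (only sleep() and wake_up() are real vLLM calls; the two callbacks are invented for illustration):

from vllm import LLM

def sync_weights_with_sleep(llm: LLM, fetch_state_dict, load_state_dict):
    # Release the engine's GPU memory while updated weights are transferred, then restore it.
    llm.sleep(level=1)               # offload weights to CPU and drop the KV cache
    state_dict = fetch_state_dict()  # e.g. receive the updated policy weights from the trainer
    load_state_dict(state_dict)      # push them into the inference engine
    llm.wake_up()                    # reallocate GPU memory and reload the weights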

View File

@@ -62,6 +62,7 @@ if __name__ == "__main__":
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 and no greater than train mini batch size * num generations"
+ assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"
if args.master_address is None:
# Default settings: Using single machine
@@ -71,7 +72,7 @@ if __name__ == "__main__":
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
inference_model_config = dict(path=args.model)
- train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
+ train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
if args.backend == "transformers":
@@ -96,13 +97,13 @@ if __name__ == "__main__":
gpu_memory_utilization=0.7,
enforce_eager=True,
enable_chunked_prefill=True,
- max_model_len=1024 * 10 + 510,
+ max_model_len=1024 * 4 + 510,
tensor_parallel_size=1,
)
)
generate_config.update(
dict(
- max_tokens=1024 * 10,
+ max_tokens=1024 * 4,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
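The generation budget and the context window shrink together here: max_tokens caps each response at 1024 * 4 = 4096 tokens, and max_model_len = 1024 * 4 + 510 leaves about 510 tokens of headroom for the prompt. A quick check of that arithmetic (the 510-token prompt budget is inferred from the expression, not stated explicitly):

max_new_tokens = 1024 * 4                       # response budget passed as max_tokens
prompt_budget = 510                             # remaining headroom for the prompt
max_model_len = max_new_tokens + prompt_budget
assert max_model_len == 4606                    # prompt + response must fit in the context window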
@@ -139,7 +140,7 @@ if __name__ == "__main__":
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
- "max_length": 1024 * 10,
+ "max_length": 1024 * 4,
"cache_length": 512,
"filter_truncated_response": True,
}
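A soft_over_length_punishment configured with max_length and cache_length usually means responses are only penalized once they enter the last cache_length tokens before max_length, with the penalty ramping up linearly; responses truncated beyond the budget can additionally be filtered out (filter_truncated_response). The exact formula is not part of this diff; a hedged sketch of one common DAPO-style variant of such a penalty:

def soft_overlength_penalty(response_length, max_length=1024 * 4, cache_length=512):
    # Return a penalty in [0, 1]: zero below max_length - cache_length, ramping linearly to 1 at max_length.
    soft_start = max_length - cache_length
    if response_length <= soft_start:
        return 0.0
    overflow = min(response_length, max_length) - soft_start
    return overflow / cache_length

print(soft_overlength_penalty(3584), soft_overlength_penalty(3840), soft_overlength_penalty(4096))
# 0.0 0.5 1.0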