Mirror of https://github.com/hpcaitech/ColossalAI.git

fix conversation; support sleep mode

commit ab55b9f033 (parent 39a1faac87)
@@ -158,6 +158,7 @@ class SimpleConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,

@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
         generate_config=None,
         grpo_config={},
         project_name=None,
+        save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":

@@ -63,6 +64,7 @@ class GRPOConsumer(BaseConsumer):
             model_config,
             plugin_config,
             microbatch_size,
+            save_dir=save_dir,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)

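
The three hunks above thread a new save_dir argument through the consumer constructors with a "./model" default. A minimal sketch of how the stored directory could be used at checkpoint time, assuming a hypothetical save step (the actual checkpoint code is not part of this diff; booster.save_model is one way a boosted ColossalAI model can be persisted):

import os

class ConsumerSaveSketch:
    def __init__(self, save_dir: str = "./model"):
        # Same default as the diff: checkpoints land under ./model unless overridden.
        self.save_dir = save_dir

    def save(self, step: int) -> str:
        # Hypothetical checkpoint step; the real consumer would call something like
        # self.booster.save_model(self.policy_model, path, shard=True).
        path = os.path.join(self.save_dir, f"modeling-step-{step}")
        os.makedirs(path, exist_ok=True)
        return path
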
@@ -173,7 +175,7 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
-        forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
+        train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

         reward_group = self.reward_model(
             data["input_ids"],

@@ -222,11 +224,11 @@ class GRPOConsumer(BaseConsumer):

         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
+        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
         pbar.set_postfix(
             {
                 "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
+                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )

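
The optimizer step is now gated on a full batch of valid samples, batch_size * dp_size * num_generations, instead of 95% of it. A minimal sketch of this gating under dynamic sample filtering, using plain arguments in place of the consumer's attributes:

def should_update(effective_sample_count: int, batch_size: int, dp_size: int, num_generations: int) -> bool:
    # Invalid samples (e.g. truncated responses) are filtered out upstream, so the
    # consumer keeps accumulating until a full batch of valid samples is collected.
    target = batch_size * dp_size * num_generations
    return effective_sample_count >= target

# With batch_size=8, dp_size=2, num_generations=8 the target is 128 valid samples;
# under the old 0.95 factor, 122 (>= 121.6) was already enough.
print(should_update(122, 8, 2, 8))  # False with the new threshold
print(should_update(128, 8, 2, 8))  # True
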
@@ -237,23 +239,23 @@ class GRPOConsumer(BaseConsumer):
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
-            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 action_mask_forward_micro_batch = action_mask[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
                 loss_mask_forward_micro_batch = (
-                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
                     if loss_mask is not None
                     else None
                 )
                 advantages_forward_micro_batch = advantages[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]

                 if self.plugin.pp_size > 1:

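
The hunk above renames the slicing stride so it matches the train_microbatch_size config key it is read from; behavior is unchanged. A minimal sketch of the slicing pattern with dummy tensors (shapes and values are illustrative only):

import torch

batch_size, seq_len, num_actions = 8, 16, 4
data = {
    "input_ids": torch.randint(0, 100, (batch_size, seq_len)),
    "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
}
action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool)
advantages = torch.randn(batch_size, num_actions)
train_microbatch_size = 2  # would come from grpo_config["train_microbatch_size"]

for start in range(0, data["input_ids"].size(0), train_microbatch_size):
    end = start + train_microbatch_size
    input_ids_mb = data["input_ids"][start:end]
    attention_mask_mb = data["attention_mask"][start:end]
    action_mask_mb = action_mask[start:end]
    advantages_mb = advantages[start:end]
    # each micro-batch is forwarded and backpropagated separately while
    # gradients accumulate until the full batch has been processed
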
@@ -442,7 +444,7 @@ class GRPOConsumer(BaseConsumer):
                 [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
                     f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
+                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",

@@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
 class VLLMInferenceBackend(BaseInferenceBackend):
     DEFAULT_MODEL_CONFIG = dict(
         trust_remote_code=True,
+        enable_sleep_mode=False,
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,

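
enable_sleep_mode=False joins the backend's default engine config, so sleep mode stays off unless the caller opts in through its model config. A minimal sketch of building a vLLM engine with the flag enabled (the model path is a placeholder, not a value from this diff):

from vllm import LLM

model_config = dict(
    trust_remote_code=True,
    enable_sleep_mode=True,  # opt in; the new backend default is False
    gpu_memory_utilization=0.7,
)
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", **model_config)  # placeholder model path
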
@@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
+        self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations

@@ -107,6 +107,7 @@ def launch_distributed(
             grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
+            save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

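
The launcher now resolves the checkpoint directory from grpo_config, falling back to a per-project default. A small worked example of that lookup (project_name and grpo_config values are illustrative):

project_name = "GRPO-demo"   # placeholder
grpo_config = {"beta": 0.0}  # no explicit save_dir set

# Same expression as the diff: an explicit save_dir wins, otherwise ./model/<project_name>.
save_dir = grpo_config.get("save_dir", f"./model/{project_name}")
print(save_dir)  # ./model/GRPO-demo
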
@@ -113,6 +113,10 @@ class BaseProducer:
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
+                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                        "enable_sleep_mode", False
+                    ):
+                        self.model.llm.sleep()  # revict KV_cache to avoid OOM
                     # don't sync model for last iteration
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"

@@ -125,6 +129,10 @@ class BaseProducer:
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
+                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                        "enable_sleep_mode", False
+                    ):
+                        self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)

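
Together, the two producer hunks put the vLLM engine to sleep before the weight sync and wake it up once the new state dict has been loaded and freed, so KV-cache memory is not held while the updated policy weights arrive. A minimal sketch of the pattern, assuming an LLM created with enable_sleep_mode=True; the actual weight loading is left as a comment because it is handled by the producer's own load_state_dict:

import torch
from vllm import LLM

def sync_policy_weights(llm: LLM, state_dict: dict) -> None:
    # Release engine memory before the (potentially large) state dict arrives.
    llm.sleep()  # evicts the KV cache to avoid OOM during the sync
    try:
        # ... broadcast/receive the new weights and load them into the engine here ...
        pass
    finally:
        del state_dict
        torch.cuda.empty_cache()
        llm.wake_up()  # restore the engine so generation can resume

# llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", enable_sleep_mode=True)  # placeholder path
# sync_policy_weights(llm, new_state_dict)
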
@@ -62,6 +62,7 @@ if __name__ == "__main__":
         args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
+    assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"

     if args.master_address is None:
         # Default settings: Using single machine

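
The added assert completes the batch-size hierarchy checks: the micro batch must be positive and no larger than train_minibatch_size * num_generations, and the mini batch must be smaller than the overall train batch. A worked example with illustrative values (the comments are a best-guess reading of the config semantics):

train_batch_size = 32       # prompts consumed per global step
train_minibatch_size = 8    # prompts per optimizer update
train_microbatch_size = 16  # sequences per forward/backward pass
num_generations = 8         # responses sampled per prompt

assert (
    train_minibatch_size * num_generations >= train_microbatch_size
    and train_microbatch_size > 0
), "Train micro batch size must be greater than 0 and at most train mini batch size * num generations"
assert train_minibatch_size < train_batch_size, "Train mini batch size must be less than train batch size"
# 8 * 8 = 64 >= 16 > 0 and 8 < 32, so both checks pass for these values.
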
@@ -71,7 +72,7 @@ if __name__ == "__main__":
         ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)

     if args.backend == "transformers":

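
Training-side models now load with FlashAttention 2 enabled. As the GRPOConsumer hunk above shows, the path entry is popped and the remaining keys are passed straight to from_pretrained; a minimal sketch of that consumption (the model path is a placeholder, and on recent transformers releases attn_implementation="flash_attention_2" is the preferred spelling of the same switch):

from transformers import AutoModelForCausalLM

train_model_config = dict(
    path="Qwen/Qwen2.5-7B-Instruct",  # placeholder path
    use_flash_attention_2=True,
    use_cache=False,
)

path = train_model_config.pop("path")
# The remaining keys become keyword arguments of from_pretrained, as in the consumer.
policy_model = AutoModelForCausalLM.from_pretrained(path, **train_model_config)
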
@@ -96,13 +97,13 @@ if __name__ == "__main__":
                 gpu_memory_utilization=0.7,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
-                max_model_len=1024 * 10 + 510,
+                max_model_len=1024 * 4 + 510,
                 tensor_parallel_size=1,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=1024 * 10,
+                max_tokens=1024 * 4,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],

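
The engine and generation length budgets shrink together from roughly 10k to 4k tokens, with max_model_len leaving 510 tokens of headroom for the prompt on top of max_tokens. A minimal sketch of the matching engine and sampling setup (values are copied from the diff; the model path is a placeholder):

from vllm import LLM, SamplingParams

max_new_tokens = 1024 * 4  # generation budget
prompt_budget = 510        # prompt allowance implied by max_model_len

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder
    gpu_memory_utilization=0.7,
    enforce_eager=True,
    enable_chunked_prefill=True,
    max_model_len=max_new_tokens + prompt_budget,  # 1024 * 4 + 510
    tensor_parallel_size=1,
)
sampling = SamplingParams(
    max_tokens=max_new_tokens,
    ignore_eos=True,
    include_stop_str_in_output=True,
    stop=["</answer>"],
)
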
@@ -139,7 +140,7 @@ if __name__ == "__main__":
         "beta": 0.0,  # no KL penalty
         "loss_variation": "token_level",
         "soft_over_length_punishment": True,
-        "max_length": 1024 * 10,
+        "max_length": 1024 * 4,
         "cache_length": 512,
         "filter_truncated_response": True,
     }
