Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 12:12:46 +00:00)

Commit: ab55b9f033
Parent: 39a1faac87

fix conversation; support sleep mode
@@ -158,6 +158,7 @@ class SimpleConsumer(BaseConsumer):
         model_config,
         plugin_config,
         microbatch_size=1,
+        save_dir="./model",
     ):
         super().__init__(
             num_producers,
@@ -41,6 +41,7 @@ class GRPOConsumer(BaseConsumer):
         generate_config=None,
         grpo_config={},
         project_name=None,
+        save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
         if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -63,6 +64,7 @@ class GRPOConsumer(BaseConsumer):
             model_config,
             plugin_config,
             microbatch_size,
+            save_dir=save_dir,
         )
         path = model_config.pop("path")
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -173,7 +175,7 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
-        forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
+        train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

         reward_group = self.reward_model(
             data["input_ids"],
@@ -222,11 +224,11 @@ class GRPOConsumer(BaseConsumer):

         # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
         # balance between efficiency and accuracy
-        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.95
+        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
         pbar.set_postfix(
             {
                 "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations * 0.95}",
+                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )

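The hunk above drops the 0.95 slack factor, so a gradient update now waits for the full batch_size * dp_size * num_generations effective samples (the "0.7" in the surrounding comment was already stale before this change). A quick illustration with invented values:

# Illustrative only; none of these values come from the commit.
batch_size, dp_size, num_generations = 4, 2, 8

old_threshold = batch_size * dp_size * num_generations * 0.95  # 60.8
new_threshold = batch_size * dp_size * num_generations         # 64

effective_sample_count = 62
print(effective_sample_count >= old_threshold)  # True: old code would update
print(effective_sample_count >= new_threshold)  # False: new code keeps collecting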
@@ -237,23 +239,23 @@ class GRPOConsumer(BaseConsumer):
                 else self.booster.no_sync(self.policy_model, self.optimizer)
             )
             with ctx:
-                for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                     input_ids_forward_micro_batch = data["input_ids"][
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                        forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]
                     attention_mask_forward_micro_batch = data["attention_mask"][
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                        forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]
                     action_mask_forward_micro_batch = action_mask[
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                        forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]
                     loss_mask_forward_micro_batch = (
-                        loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                        loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
                         if loss_mask is not None
                         else None
                     )
                     advantages_forward_micro_batch = advantages[
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                        forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                     ]

                     if self.plugin.pp_size > 1:
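The renamed variable is then used to walk the collected rollout batch in fixed-size slices, forwarding one slice at a time to bound peak activation memory. A self-contained sketch of the same slicing pattern (shapes and names invented for illustration):

import torch

# Invented shapes: 8 sequences of length 16, sliced into microbatches of 2.
input_ids = torch.randint(0, 100, (8, 16))
attention_mask = torch.ones(8, 16, dtype=torch.long)
train_microbatch_size = 2

for start in range(0, input_ids.size(0), train_microbatch_size):
    input_ids_mb = input_ids[start : start + train_microbatch_size]
    attention_mask_mb = attention_mask[start : start + train_microbatch_size]
    # each microbatch is forwarded separately; together they cover the batch
    assert input_ids_mb.size(0) <= train_microbatch_size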
@@ -442,7 +444,7 @@ class GRPOConsumer(BaseConsumer):
                     [
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
                         f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                        f"ormat Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
+                        f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
                         f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                         f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                         f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
@@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
 class VLLMInferenceBackend(BaseInferenceBackend):
     DEFAULT_MODEL_CONFIG = dict(
         trust_remote_code=True,
+        enable_sleep_mode=False,
     )
     FORCE_GENERATE_CONFIG = dict(
         logprobs=0,
@@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         generate_config.update(self.FORCE_GENERATE_CONFIG)
         generate_config.update({"n": num_generations})
         self.generate_config = SamplingParams(**generate_config)
+        self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations

@@ -107,6 +107,7 @@ def launch_distributed(
             grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
+            save_dir=grpo_config.get("save_dir", f"./model/{project_name}"),
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
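With this, the consumer's save directory falls back to a per-project default when grpo_config does not specify one. A tiny sketch of the fallback (project name invented):

# Invented example of the save_dir fallback added above.
project_name = "GRPO-demo"
grpo_config = {}  # no explicit save_dir configured

save_dir = grpo_config.get("save_dir", f"./model/{project_name}")
print(save_dir)  # ./model/GRPO-demo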
@@ -113,6 +113,10 @@ class BaseProducer:
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
+                    if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                        "enable_sleep_mode", False
+                    ):
+                        self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
                     print(
                         f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
@@ -125,6 +129,10 @@ class BaseProducer:
                     self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
+                if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
+                    "enable_sleep_mode", False
+                ):
+                    self.model.llm.wake_up()
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
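Together with the backend change that stores model_config, these two hunks bracket each weight sync with vLLM's sleep mode: the engine sleeps before the sync to release KV-cache and weight memory, then wakes up once the new state dict is loaded. A minimal standalone sketch of that pattern, assuming vLLM's public sleep()/wake_up() API (model path invented):

from vllm import LLM

# enable_sleep_mode must be set when the engine is constructed
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enable_sleep_mode=True)

outputs = llm.generate(["1 + 1 ="])  # a normal rollout step

llm.sleep(level=1)  # offload weights and discard KV cache to free GPU memory
# ... receive and load updated policy weights here ...
llm.wake_up()       # restore engine memory before the next rollout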
@@ -62,6 +62,7 @@ if __name__ == "__main__":
         args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 and no larger than train mini batch size * num generations"
+    assert args.train_minibatch_size < args.train_batch_size, "Train mini batch size must be less than train batch size"

     if args.master_address is None:
         # Default settings: Using single machine
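The new assertion completes a simple hierarchy among the batch-size knobs: 0 < train_microbatch_size <= train_minibatch_size * num_generations, and train_minibatch_size < train_batch_size. One invented configuration that satisfies both checks:

# Invented values that pass both assertions.
train_batch_size = 32
train_minibatch_size = 16
train_microbatch_size = 8
num_generations = 8

assert (
    train_minibatch_size * num_generations >= train_microbatch_size
    and train_microbatch_size > 0
)
assert train_minibatch_size < train_batch_size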
@@ -71,7 +72,7 @@ if __name__ == "__main__":
         ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)

     if args.backend == "transformers":
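Turning on use_flash_attention_2 for the trained policy model relies on HF transformers' FlashAttention-2 path, which requires the flash-attn package and half-precision weights. A hedged sketch (model path and dtype are illustrative, not from this commit):

import torch
from transformers import AutoModelForCausalLM

# use_flash_attention_2=True needs flash-attn installed and fp16/bf16
# weights; the model path below is invented.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct",
    use_flash_attention_2=True,
    torch_dtype=torch.bfloat16,
    use_cache=False,
)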
@@ -96,13 +97,13 @@ if __name__ == "__main__":
                 gpu_memory_utilization=0.7,
                 enforce_eager=True,
                 enable_chunked_prefill=True,
-                max_model_len=1024 * 10 + 510,
+                max_model_len=1024 * 4 + 510,
                 tensor_parallel_size=1,
             )
         )
         generate_config.update(
             dict(
-                max_tokens=1024 * 10,
+                max_tokens=1024 * 4,
                 ignore_eos=True,
                 include_stop_str_in_output=True,
                 stop=["</answer>"],
@@ -139,7 +140,7 @@ if __name__ == "__main__":
             "beta": 0.0,  # no KL penalty
             "loss_variation": "token_level",
             "soft_over_length_punishment": True,
-            "max_length": 1024 * 10,
+            "max_length": 1024 * 4,
             "cache_length": 512,
             "filter_truncated_response": True,
         }
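The shortened length budgets stay mutually consistent across the three hunks above: generation is capped at 1024 * 4 = 4096 tokens, the engine context of 1024 * 4 + 510 = 4606 tokens appears to leave 510 tokens of headroom for the prompt, and the GRPO-side max_length matches the 4096-token generation cap. The arithmetic, spelled out:

# Sanity arithmetic for the new length budgets (values from the diff;
# reading the 510 as prompt headroom is an assumption).
max_tokens = 1024 * 4            # 4096 generated tokens per sample
max_model_len = 1024 * 4 + 510   # 4606-token engine context
prompt_budget = max_model_len - max_tokens
assert prompt_budget == 510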