mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-11 12:51:55 +00:00
support resume training, still buggy
This commit is contained in:
parent
5c5cb1863b
commit
65f5289e35
@ -55,6 +55,7 @@ class BaseConsumer:
|
||||
self.enable_profiling = enable_profiling
|
||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // minibatch_size
|
||||
self.checkpoint_path = model_config.pop("checkpoint_path", None)
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
@ -143,6 +144,26 @@ class BaseConsumer:
|
||||
return effective_group_to_raw_group_mapping
|
||||
|
||||
def loop(self) -> None:
|
||||
self.profiler.enter("sync_model")
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.state_dict()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict,
|
||||
src=self.num_producers,
|
||||
device=self.device,
|
||||
group_name=f"sync_model_{self.pp_rank}",
|
||||
)
|
||||
else:
|
||||
if self.rank == 0:
|
||||
ray_broadcast_tensor_dict(
|
||||
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.exit("sync_model")
|
||||
|
||||
print(
|
||||
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
|
||||
)
|
||||
@ -286,7 +307,7 @@ class BaseConsumer:
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {step + 1}.")
|
||||
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
|
||||
|
||||
@ -365,7 +386,7 @@ class SimpleConsumer(BaseConsumer):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.model.train()
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
|
||||
self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
|
||||
def setup(self):
|
||||
|
@ -72,7 +72,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.optimizer = HybridAdam(
|
||||
self.policy_model.parameters(),
|
||||
lr=grpo_config.get("lr", 1e-6),
|
||||
weight_decay=grpo_config.get("weight_decay", 0.01),
|
||||
)
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_entropy = torch.zeros(1, device=self.device)
|
||||
@ -153,6 +157,8 @@ class GRPOConsumer(BaseConsumer):
|
||||
)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
if self.checkpoint_path is not None:
|
||||
self.booster.load_model(self.policy_model, self.checkpoint_path)
|
||||
self.plugin.logger.set_level("ERROR")
|
||||
|
||||
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
|
||||
|
@ -55,7 +55,6 @@ def launch_distributed(
|
||||
eval_dataset_config: Optional[Dict[str, Any]] = None,
|
||||
eval_interval: int = 100,
|
||||
eval_save_dir: Optional[str] = None,
|
||||
eval_generation_config: Optional[Dict[str, Any]] = None,
|
||||
log_rollout_interval: int = 20,
|
||||
rollout_save_dir: str = "./rollout",
|
||||
enable_profiling: bool = False,
|
||||
@ -139,7 +138,6 @@ def launch_distributed(
|
||||
eval_interval=eval_interval,
|
||||
grpo_config=grpo_config,
|
||||
eval_save_dir=eval_save_dir,
|
||||
eval_generation_config=eval_generation_config,
|
||||
project_name=project_name,
|
||||
run_name=run_name,
|
||||
wandb_group_name=wandb_group_name,
|
||||
|
@ -203,6 +203,28 @@ class BaseProducer:
|
||||
raise NotImplementedError
|
||||
|
||||
def loop(self) -> None:
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
self.profiler.enter("sync_model")
|
||||
if self.consumer_pp_size > 1:
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
)
|
||||
if "consumer_global_step" in state_dict:
|
||||
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
|
||||
self.load_state_dict(state_dict)
|
||||
self.profiler.exit("sync_model")
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
|
||||
num_valid_microbatches = num_update_per_episode * self.num_microbatches
|
||||
|
||||
|
@ -25,7 +25,12 @@ from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
|
||||
|
||||
from .code_reward.utils import check_correctness_code_api as check_correctness_code
|
||||
from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
|
||||
from .reward_utils import (
|
||||
extract_boxed_solution,
|
||||
extract_solution,
|
||||
find_infinite_loop_start,
|
||||
validate_response_structure,
|
||||
)
|
||||
|
||||
CANNOT_PARSE_GT_ANSWER = -1
|
||||
CANNOT_PARSE_PREDICTION = -2
|
||||
@ -122,6 +127,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
|
||||
repetition_reward = 1.0 if detect_repetition(decoded_final_answer) == [] else 0.0
|
||||
|
||||
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||
|
||||
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||
@ -137,6 +144,10 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
|
||||
# Add repetition reward
|
||||
if not eval_mode:
|
||||
reward += repetition_reward
|
||||
|
||||
# Check if the sequence is over length
|
||||
if not eval_mode and res_length >= max_new_tokens:
|
||||
reward *= 0.0
|
||||
@ -182,6 +193,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
raise ValueError("no gt_answer is provided, please check your training dataset.")
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
print(f"Decoded final answer: {decoded_final_answer[-500:]}")
|
||||
repetition_score = find_infinite_loop_start(input_ids[s : e + 1], min_repeats=2, distance=False)
|
||||
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
format_valid = final_answer is not None
|
||||
@ -202,6 +215,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
if format_valid:
|
||||
format_acc += 1
|
||||
|
||||
if not repetition_score > 0 and not eval_mode:
|
||||
# award for non-repetition
|
||||
reward += 2
|
||||
|
||||
# Check if the sequence is over length
|
||||
if not eval_mode and res_length >= max_new_tokens:
|
||||
reward *= 0.0
|
||||
|
@ -14,7 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
|
||||
@ -122,3 +124,51 @@ def extract_boxed_solution(text: str) -> Optional[str]:
|
||||
except Exception:
|
||||
# Any other unexpected error
|
||||
return None
|
||||
|
||||
|
||||
import Levenshtein
|
||||
|
||||
|
||||
def is_similar(seq1: List[int], seq2: List[int], threshold: float = 0.9) -> bool:
|
||||
ratio = Levenshtein.ratio(seq1, seq2)
|
||||
return ratio >= threshold
|
||||
|
||||
|
||||
def find_infinite_loop_start(token_ids: List[int], min_repeats: int = 2, distance: bool = False) -> float:
|
||||
n = len(token_ids)
|
||||
|
||||
# Step 1: Detect the repeating segment at the end using two pointers
|
||||
longest_valid_length = 0
|
||||
start_of_loop = n
|
||||
|
||||
for length in range(1, n // min_repeats + 1): # Try different phrase lengths
|
||||
count = 1 # Reset repetition counter
|
||||
right = n - length # Start comparing from the second last occurrence
|
||||
|
||||
while right - length >= 0:
|
||||
# Check if the current phrase matches the previous phrase
|
||||
if distance:
|
||||
if is_similar(token_ids[right - length : right], token_ids[right : right + length]):
|
||||
count += 1
|
||||
else:
|
||||
break # Stop if repetition is broken
|
||||
else:
|
||||
# Use torch.equal() for tensor comparison
|
||||
if torch.equal(token_ids[right - length : right], token_ids[right : right + length]):
|
||||
count += 1
|
||||
else:
|
||||
break # Stop if repetition is broken
|
||||
|
||||
right -= length # Move left to check further
|
||||
|
||||
if count >= min_repeats: # Found a valid repeating phrase
|
||||
longest_valid_length = length
|
||||
start_of_loop = right # This is where the first cycle of the repetition begins
|
||||
|
||||
if longest_valid_length == 0:
|
||||
return 0.0 # No infinite loop found, return repetition ratio as 0
|
||||
|
||||
# Step 2: Compute the repetition ratio
|
||||
repetition_ratio = (n - start_of_loop) / n
|
||||
|
||||
return repetition_ratio
|
||||
|
@ -55,7 +55,7 @@ class BaseConsumer:
|
||||
self.num_microbatches = batch_size // minibatch_size
|
||||
self.data_uid = 0
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
self.checkpoint_path = model_config.pop("checkpoint_path", None)
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
@ -79,6 +79,7 @@ class BaseConsumer:
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
|
||||
self.dp_rank = dist.get_rank(self.plugin.dp_group)
|
||||
self.tp_rank = dist.get_rank(self.plugin.tp_group)
|
||||
self.pp_rank = dist.get_rank(self.plugin.pp_group)
|
||||
@ -165,8 +166,93 @@ class BaseConsumer:
|
||||
desc=f"Episode {episode} with rollout step(s)",
|
||||
disable=self.rank != 0,
|
||||
) as pbar:
|
||||
need_sync_model = True
|
||||
while self.received_prompts < self.train_dataset_size:
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
if need_sync_model and (
|
||||
(self.global_step + 1) % self.save_interval == 0
|
||||
or self.received_prompts >= self.train_dataset_size
|
||||
):
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {self.global_step + 1}.")
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
|
||||
)
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True, use_safetensors=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
|
||||
|
||||
if need_sync_model and (
|
||||
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
|
||||
):
|
||||
|
||||
def sync_model_thread():
|
||||
# sync model weights to all producers, if no model update or it is the last training step, skip syncing
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
torch.cuda.empty_cache()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
else:
|
||||
if self.rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name="sync_model_consumer",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
|
||||
if not self.sync_model_thread_started:
|
||||
# only sync model when the thread is not started and no other thread is broadcasting
|
||||
self.sync_model_thread_started = True
|
||||
state_dict_ = self.state_dict()
|
||||
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
|
||||
self.pp_size == 1 and self.rank == 0
|
||||
):
|
||||
if len(self.state_dict_cpu) == 0:
|
||||
# use pinned memory to speed up the transfer
|
||||
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
|
||||
torch.cuda.synchronize()
|
||||
for k, v in state_dict_.items():
|
||||
self.state_dict_cpu[k].copy_(v, non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
cc.barrier(
|
||||
group_name="consumer_pg"
|
||||
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
|
||||
time_before_starting_thread = time.time()
|
||||
threading.Thread(target=sync_model_thread).start()
|
||||
# sync_model_thread()
|
||||
self.profiler.log(
|
||||
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
|
||||
)
|
||||
self.sync_model_thread_started = False
|
||||
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
|
||||
effective_group_to_raw_group_mapping = {}
|
||||
self.profiler.enter(f"recv_data")
|
||||
while len(effective_group_to_raw_group_mapping) < self.dp_size * self.minibatch_size:
|
||||
@ -254,90 +340,6 @@ class BaseConsumer:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
need_sync_model = True
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("global_step", self.global_step + 1))
|
||||
if need_sync_model and (
|
||||
(self.global_step + 1) % self.save_interval == 0
|
||||
or self.received_prompts >= self.train_dataset_size
|
||||
):
|
||||
if self.rank == 0:
|
||||
print(f"Start saving policy model at step {self.global_step + 1}.")
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"modeling-episode-{episode}-step-{self.global_step + 1}"
|
||||
)
|
||||
self.booster.save_model(self.policy_model, save_path, shard=True)
|
||||
if self.rank == 0:
|
||||
print(f"Saved model checkpoint at step {self.global_step + 1} in folder {save_path}")
|
||||
|
||||
if need_sync_model and (
|
||||
episode != self.num_episodes - 1 or self.received_prompts != self.train_dataset_size
|
||||
):
|
||||
|
||||
def sync_model_thread():
|
||||
# sync model weights to all producers, if no model update or it is the last training step, skip syncing
|
||||
if self.pp_size > 1:
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
else:
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
torch.cuda.empty_cache()
|
||||
if self.pp_size > 1:
|
||||
if self.tp_rank == 0 and self.dp_rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(
|
||||
self.shared_signal_actor.set_signal.remote(
|
||||
f"consumer_pp_{self.pp_rank}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {self.global_step}"
|
||||
)
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name=f"sync_model_consumer_pp_{self.pp_rank}",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
else:
|
||||
if self.rank == 0:
|
||||
self.profiler.enter("sync_model")
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "ready_sync_model"))
|
||||
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {self.global_step}")
|
||||
ray_broadcast_tensor_dict(
|
||||
self.state_dict_cpu,
|
||||
src=0,
|
||||
device=torch.device("cpu"),
|
||||
group_name="sync_model_consumer",
|
||||
backend="gloo",
|
||||
)
|
||||
self.profiler.exit("sync_model")
|
||||
|
||||
if not self.sync_model_thread_started:
|
||||
# only sync model when the thread is not started and no other thread is broadcasting
|
||||
self.sync_model_thread_started = True
|
||||
state_dict_ = self.state_dict()
|
||||
if (self.pp_size > 1 and self.tp_rank == 0 and self.dp_rank == 0) or (
|
||||
self.pp_size == 1 and self.rank == 0
|
||||
):
|
||||
if len(self.state_dict_cpu) == 0:
|
||||
# use pinned memory to speed up the transfer
|
||||
self.state_dict_cpu = {k: v.cpu().pin_memory() for k, v in state_dict_.items()}
|
||||
torch.cuda.synchronize()
|
||||
for k, v in state_dict_.items():
|
||||
self.state_dict_cpu[k].copy_(v, non_blocking=True)
|
||||
torch.cuda.synchronize()
|
||||
cc.barrier(
|
||||
group_name="consumer_pg"
|
||||
) # to make sure all ranks have state dict offloaded to CPU before starting the thread
|
||||
time_before_starting_thread = time.time()
|
||||
threading.Thread(target=sync_model_thread).start()
|
||||
# sync_model_thread()
|
||||
self.profiler.log(
|
||||
f"Sync model, took {time.time() - time_before_starting_thread:.2f} seconds"
|
||||
)
|
||||
self.sync_model_thread_started = False
|
||||
# ray.get(self.shared_signal_actor.release_process_lock.remote("broadcasting_lock"))
|
||||
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||
self.received_prompts = 0
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("consumer", "terminate"))
|
||||
|
@ -20,7 +20,7 @@ class Distributor:
|
||||
enable_profiling: bool = True,
|
||||
):
|
||||
self.distributor_id = distributor_id
|
||||
self.weight_version = [0] * consumer_pp_size
|
||||
self.weight_version = [-1] * consumer_pp_size
|
||||
self.consumer_pp_size = consumer_pp_size
|
||||
self.state_dict_cpu = {}
|
||||
self.num_producers = num_producers
|
||||
|
@ -75,7 +75,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.vocab_size = self.policy_model.config.vocab_size
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.optimizer = HybridAdam(
|
||||
self.policy_model.parameters(),
|
||||
lr=grpo_config.get("lr", 1e-6),
|
||||
weight_decay=grpo_config.get("weight_decay", 0.01),
|
||||
)
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_entropy = torch.zeros(1, device=self.device)
|
||||
@ -157,6 +161,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||
)
|
||||
if self.checkpoint_path is not None:
|
||||
print(f"Resume training from checkpoint: {self.checkpoint_path}")
|
||||
self.booster.load_model(self.policy_model, self.checkpoint_path)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
self.plugin.logger.set_level("ERROR")
|
||||
|
@ -82,7 +82,7 @@ class BaseProducer:
|
||||
self.eval_interval = eval_interval
|
||||
self.eval_save_dir = eval_save_dir
|
||||
self.consumer_global_step = 0
|
||||
self.producer_weight_version = 0
|
||||
self.producer_weight_version = -1
|
||||
self.eval_mode = False
|
||||
self.log_rollout_interval = log_rollout_interval
|
||||
self.latest_rollout_log_step = -1
|
||||
@ -148,19 +148,38 @@ class BaseProducer:
|
||||
self.evaluation_function = code_reward_fn
|
||||
else:
|
||||
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
|
||||
|
||||
self.eval_dataset_config = eval_dataset_config
|
||||
if self.eval_dataset_config is not None:
|
||||
self.eval_dataloaders = {}
|
||||
self.eval_sample_params = {}
|
||||
for eval_task_name in self.eval_dataset_config:
|
||||
eval_sampling_params = copy.deepcopy(generate_config)
|
||||
eval_sampling_params["n"] = eval_dataset_config[eval_task_name].pop(
|
||||
"n", 1
|
||||
) # use 1 generation for evaluation
|
||||
if "max_tokens" in eval_sampling_params:
|
||||
eval_sampling_params["max_tokens"] = eval_dataset_config[eval_task_name].pop(
|
||||
"max_tokens", eval_sampling_params["max_tokens"]
|
||||
)
|
||||
elif "max_new_tokens" in eval_sampling_params:
|
||||
eval_sampling_params["max_new_tokens"] = eval_dataset_config[eval_task_name].pop(
|
||||
"max_new_tokens", eval_sampling_params["max_new_tokens"]
|
||||
)
|
||||
eval_sampling_params["temperature"] = eval_dataset_config[eval_task_name].pop("temperature", 0.6)
|
||||
eval_sampling_params["logprobs"] = 0 # force parameter
|
||||
self.eval_sample_params[eval_task_name] = SamplingParams(**eval_sampling_params)
|
||||
|
||||
eval_dataset_path = eval_dataset_config[eval_task_name].pop("path")
|
||||
eval_inference_batch_size = eval_dataset_config[eval_task_name].pop("batch_size", microbatch_size)
|
||||
eval_dataset = RawConversationDataset(
|
||||
self.tokenizer, eval_dataset_path, **eval_dataset_config[eval_task_name]
|
||||
)
|
||||
print(f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}")
|
||||
print(
|
||||
f"[P{self.producer_idx}] eval dataset {eval_task_name} size: {len(eval_dataset)}, eval_batch_size: {eval_inference_batch_size} num_producers:{num_producers}"
|
||||
)
|
||||
self.eval_dataloaders[eval_task_name] = DataLoader(
|
||||
eval_dataset,
|
||||
batch_size=microbatch_size,
|
||||
batch_size=eval_inference_batch_size,
|
||||
sampler=DistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=num_producers,
|
||||
@ -280,7 +299,17 @@ class BaseProducer:
|
||||
self.sync_model_thread_started = False
|
||||
|
||||
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
|
||||
f"distributor_weight_version", 0
|
||||
f"distributor_weight_version", -1
|
||||
)
|
||||
print(f"[P{self.producer_idx}] Distributor weight version: {distributor_weight_version}")
|
||||
is_initial_sync = False
|
||||
while distributor_weight_version < 0:
|
||||
# wait for distributor to set weight version
|
||||
is_initial_sync = True
|
||||
print(f"[P{self.producer_idx}] Waiting for intial model sync before training")
|
||||
time.sleep(1)
|
||||
distributor_weight_version = ray.get(self.shared_signal_actor.get_signal.remote()).get(
|
||||
f"distributor_weight_version", -1
|
||||
)
|
||||
if (
|
||||
not self.sync_model_thread_started
|
||||
@ -291,6 +320,16 @@ class BaseProducer:
|
||||
self.sync_model_thread = threading.Thread(target=sync_model_thread)
|
||||
self.producer_weight_version = distributor_weight_version
|
||||
self.sync_model_thread.start()
|
||||
if is_initial_sync:
|
||||
# wait till initial model sync is done and load the model state dict
|
||||
while self.sync_model_thread_started:
|
||||
print(f"[P{self.producer_idx}] Waiting for model sync to finish before evaluation")
|
||||
time.sleep(1)
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
if self.state_dict_cpu[pp_idx] is not None and self.state_dict_cpu[pp_idx] != {}:
|
||||
self.load_state_dict(self.state_dict_cpu[pp_idx])
|
||||
print(f"[P{self.producer_idx}] loaded initial model state dict")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
@ -314,7 +353,9 @@ class BaseProducer:
|
||||
for eval_batch in tqdm.tqdm(
|
||||
self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
|
||||
):
|
||||
eval_outputs = self.rollout(**eval_batch, sample_params=self.eval_sample_params)
|
||||
eval_outputs = self.rollout(
|
||||
**eval_batch, sample_params=copy.deepcopy(self.eval_sample_params[eval_task_name])
|
||||
)
|
||||
eval_results = eval_results + [
|
||||
self.evaluation_function(
|
||||
eval_outputs["input_ids"][m][n],
|
||||
@ -343,6 +384,7 @@ class BaseProducer:
|
||||
print(
|
||||
f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
|
||||
)
|
||||
|
||||
# save eval results
|
||||
safe_append_to_jsonl_file(
|
||||
os.path.join(
|
||||
@ -351,7 +393,7 @@ class BaseProducer:
|
||||
),
|
||||
eval_results,
|
||||
)
|
||||
|
||||
raise ValueError()
|
||||
if self.producer_idx == 0:
|
||||
self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
|
||||
self.eval_mode = False
|
||||
@ -427,15 +469,15 @@ class BaseProducer:
|
||||
self.load_state_dict(self.state_dict_cpu[pp_idx])
|
||||
|
||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||
if episode <= 0:
|
||||
ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]):
|
||||
self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
# if episode <= 0:
|
||||
# ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
|
||||
# self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||
# "temperature"
|
||||
# ] + ratio * 0.9
|
||||
# if isinstance(self.model, BACKEND_MAP["vllm"]):
|
||||
# self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
|
||||
# "temperature"
|
||||
# ] + ratio * 0.9
|
||||
|
||||
def __del__(self):
|
||||
self.profiler.close()
|
||||
@ -464,7 +506,6 @@ class SimpleProducer(BaseProducer):
|
||||
eval_interval=-1, # disable evaluation
|
||||
grpo_config: Dict[str, Any] = None,
|
||||
eval_save_dir: str = "./eval",
|
||||
eval_generation_config={},
|
||||
project_name: str = None,
|
||||
run_name: str = None,
|
||||
wandb_group_name: str = None,
|
||||
@ -500,10 +541,6 @@ class SimpleProducer(BaseProducer):
|
||||
)
|
||||
print("tokenizer_config", tokenizer_config)
|
||||
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations, tokenizer_config)
|
||||
self.eval_generation_config = copy.deepcopy(self.model.generate_config)
|
||||
self.eval_generation_config["n"] = 1 # use 1 generation for evaluation
|
||||
self.eval_generation_config.update(eval_generation_config)
|
||||
self.eval_sample_params = SamplingParams(**self.eval_generation_config)
|
||||
|
||||
@torch.no_grad()
|
||||
def rollout(self, input_ids, attention_mask, **kwargs):
|
||||
|
@ -11,3 +11,9 @@ rm -rf *.prof
|
||||
MAX_NEW_TOKENS=$((4096-512))
|
||||
python rl_example.py --dataset /mnt/nfs/yeanbang/experiments/RLHF/grpo/train-alignment-samll.jsonl --model /home/share/data/model/Qwen2.5-Math-7B/ -t 4 -i 4 -b vllm -a GRPO -ibs 8 -tbs 8 -e 1 -rt boxed -si 100 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 4 -tmbs 4 -p GRPO-Math-Profile -ei -5 -zero 2 -mnt $MAX_NEW_TOKENS -nb 1 --enable_profiling 2>&1| tee ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.txt
|
||||
python visualization.py --visualization actor_timelines_ibs_32_tbs_32_tmbs_4_zero_2_4096_GRPO_profile.png
|
||||
|
||||
|
||||
python rl_example_zero_bubble.py --dataset /home/litong/workspace/data/rl-math-train-final.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline -ei 15 -ed '{"Math_500_level_1": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_1.jsonl", "Math_500_level_2": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_2.jsonl", "Math_500_level_3": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_3.jsonl", "Math_500_level_4": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_4.jsonl", "Math_500_level_5": "/mnt/nfs/yeanbang/data/RLHF_data/Math-500/Math_500_level_5.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 512 -mnt 15872 --data_actor_buffer_size_limit 48
|
||||
|
||||
|
||||
python rl_example_zero_bubble.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B -t 4 -i 4 -b vllm -a DAPO -imbs 4 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 2 -tmbs 2 -p DAPO-Experiment-Baseline-V3 -ei 15 -ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' -zero 1 -tp 2 -pp 2 -mpt 1024 -mnt 16384 --data_actor_buffer_size_limit 64 --enable_profiling 2>&1| tee ibs_64_tbs_32_tmbs_2_pp_2_pp_2_16K_DAPO_profile_v3.txt
|
||||
|
9
applications/ColossalChat/reproduce.sh
Normal file
9
applications/ColossalChat/reproduce.sh
Normal file
@ -0,0 +1,9 @@
|
||||
export https_proxy=http://vpn.luchentech.com:32171
|
||||
export http_proxy=http://vpn.luchentech.com:32171
|
||||
export all_proxy=socks5://vpn.luchentech.com:32170
|
||||
|
||||
# 输出正常
|
||||
# python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling
|
||||
|
||||
# 输出乱码
|
||||
python rl_example.py --dataset /mnt/nfs/yeanbang/data/RLHF_data/miromina/miromina.jsonl --model /mnt/nfs/share/data/models/Miromind-M1-SFT-7B/ --checkpoint-path /mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3/modeling-episode-0-step-200 -t 4 -i 4 -b vllm -a DAPO -imbs 8 -ibs 8 -tbs 32 -e 2 -g 16 -rt boxed -si 25 -s "Please reason step by step, and put your final answer within \\boxed{}." -tMbs 8 -tmbs 8 -p DAPO-Experiment-Baseline-V3_1 -zero 1 -pp 2 -tp 2 -mpt 1024 -mnt 4096 -nb 1 --enable_profiling
|
@ -18,6 +18,13 @@ os.environ["no_proxy"] = "127.0.0.1,localhost"
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
parser.add_argument(
|
||||
"-cp",
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
|
||||
)
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument(
|
||||
"-ed",
|
||||
@ -227,7 +234,9 @@ if __name__ == "__main__":
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
train_model_config = dict(
|
||||
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
|
||||
)
|
||||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||
|
||||
if args.backend == "transformers":
|
||||
|
@ -261,7 +261,6 @@ if __name__ == "__main__":
|
||||
stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
|
||||
)
|
||||
)
|
||||
eval_generation_config = {"temperature": 0.6} # used to update generation config for evaluation
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
@ -290,7 +289,7 @@ if __name__ == "__main__":
|
||||
elif args.algo == "DAPO":
|
||||
# DAPO variant settings
|
||||
grpo_config = {
|
||||
"filter_range": [0.01, 0.7], # only filter out all zero batch and all one batch
|
||||
"filter_range": [0.01, 0.8], # only filter out all zero batch and all one batch
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"dynamic_batching": True,
|
||||
@ -302,8 +301,8 @@ if __name__ == "__main__":
|
||||
"soft_over_length_punishment": True,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||
"filter_truncated_response": True,
|
||||
"cache_length": max(1024, int(args.max_new_tokens / 4)),
|
||||
"filter_truncated_response": False,
|
||||
"reward_fn_type": args.reward_type,
|
||||
"response_format_tags": (
|
||||
{
|
||||
@ -362,15 +361,22 @@ if __name__ == "__main__":
|
||||
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
||||
eval_dataset_config=(
|
||||
{
|
||||
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
|
||||
k: {
|
||||
"path": v,
|
||||
"max_length": args.max_prompt_tokens,
|
||||
"system_prompt": args.system_prompt,
|
||||
"batch_size": 1,
|
||||
"n": 16,
|
||||
"max_tokens": 32 * 1024 - 512,
|
||||
"temperature": 0.6,
|
||||
}
|
||||
for k, v in json.loads(args.eval_dataset).items()
|
||||
}
|
||||
if args.eval_dataset
|
||||
else None
|
||||
),
|
||||
), # support using different evaluation prompts/ parameters for each task
|
||||
eval_interval=args.eval_interval,
|
||||
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
|
||||
eval_generation_config=eval_generation_config,
|
||||
log_rollout_interval=20,
|
||||
rollout_save_dir=args.rollout_save_dir,
|
||||
enable_profiling=args.enable_profiling,
|
||||
|
54
applications/ColossalChat/run_eval.sh
Executable file
54
applications/ColossalChat/run_eval.sh
Executable file
@ -0,0 +1,54 @@
|
||||
# run eval for every saved checkpoint
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
# Set the base directory where models are saved
|
||||
BASE_DIR="/mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline"
|
||||
BASE_DIR="/mnt/nfs/yeanbang/experiments/rlhf/dapo/v3/model/DAPO-Experiment-Baseline-V3"
|
||||
DATASET="/home/litong/workspace/data/rl-math-train-final.jsonl"
|
||||
EVAL_SCRIPT="rl_example_zero_bubble.py"
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-105
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-120
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-15
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-30
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-45
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-60
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-75
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-0-step-90
|
||||
# /mnt/nfs/yeanbang/ColossalAI/applications/ColossalChat/model/DAPO-Experiment-Baseline/modeling-episode-1-step-150
|
||||
# Loop through each model directory matching the pattern
|
||||
# 64K 65024
|
||||
# 32K 32256
|
||||
for model_path in "$BASE_DIR"/modeling-episode-0-step-150; do
|
||||
if [ -d "$model_path" ]; then
|
||||
echo "Evaluating model: $model_path"
|
||||
|
||||
python "$EVAL_SCRIPT" \
|
||||
--dataset "$DATASET" \
|
||||
--tokenizer-path "/home/litong/workspace/model/DeepSeek-R1-Distill-Qwen-7B" \
|
||||
--model "$model_path" \
|
||||
-t 1 \
|
||||
-i 6 \
|
||||
-b vllm \
|
||||
-a GRPO \
|
||||
-imbs 1 \
|
||||
-ibs 8 \
|
||||
-tbs 8 \
|
||||
-e 2 \
|
||||
-g 16 \
|
||||
-rt boxed \
|
||||
-si 25 \
|
||||
-s "Please reason step by step, and put your final answer within \\boxed{}." \
|
||||
-tMbs 2 \
|
||||
-tmbs 1 \
|
||||
-p DAPO-Experiment-Baseline-Eval \
|
||||
-ei 15 \
|
||||
-ed '{"AIME2024": "/mnt/nfs/yeanbang/data/RLHF_data/AIME/2024/AIME2024.jsonl"}' \
|
||||
-zero 2 \
|
||||
-mpt 512 \
|
||||
-mnt 65024 \
|
||||
--data_actor_buffer_size_limit 4 | grep "Accuracy on"
|
||||
|
||||
echo "Finished evaluating: $model_path"
|
||||
fi
|
||||
done
|
@ -824,7 +824,7 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
force_sp_output_gather=False,
|
||||
# force_sp_output_gather=False,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
Loading…
Reference in New Issue
Block a user