diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index dc3389e21..e360392e7 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -134,8 +134,12 @@ class BaseConsumer:
     def calculate_effective_group_to_raw_group_mapping(self, step):
         effective_group_to_raw_group_mapping = {}
         for buffer_idx in range(len(self.buffer)):
-            if self.buffer[buffer_idx][0] is not None and (self.buffer[buffer_idx][-1] <= step - self.n_behind):
-                effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+            if self.buffer[buffer_idx][0] is not None:
+                if self.n_behind == 0:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
+                else:
+                    if self.buffer[buffer_idx][-1] <= step - self.n_behind:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
         return effective_group_to_raw_group_mapping
 
     def loop(self) -> None:
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index da381f8a7..1c77b69e6 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -246,14 +246,25 @@ if __name__ == "__main__":
                 tensor_parallel_size=args.producer_tensor_parallel_size,
             )
         )
-        generate_config.update(
-            dict(
-                max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
-                include_stop_str_in_output=True,
-                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+        if args.enable_profiling:
+            # If profiling is enabled, we force the model to generate to max_new_tokens
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True,
+                    include_stop_str_in_output=True,
+                    stop=None,
+                )
+            )
+        else:
+            generate_config.update(
+                dict(
+                    max_tokens=args.max_new_tokens,  # max new tokens
+                    ignore_eos=True if args.reward_type == "think_answer_tags" else False,
+                    include_stop_str_in_output=True,
+                    stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
+                )
             )
-        )
         eval_generation_config = {"temperature": 0.6}  # used to update generation config for evaluation
     else:
         raise ValueError(f"Unsupported backend: {args.backend}")
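
The consumer.py hunk changes the staleness filter so that n_behind == 0 means "no staleness constraint": every filled buffer slot is eligible immediately, instead of being compared against step. A minimal standalone sketch of the new rule follows; the free function effective_mapping and the tuple layout (data, produced_at_step) are illustrative assumptions, not the actual BaseConsumer internals.

    # Sketch of the new eligibility rule: with n_behind == 0 every filled buffer
    # slot is usable at once; with n_behind > 0 a slot is usable only once its
    # recorded rollout step is at least n_behind steps behind the current step.
    def effective_mapping(buffer, step, n_behind):
        mapping = {}
        for buffer_idx, entry in enumerate(buffer):
            if entry[0] is None:  # slot not filled yet
                continue
            # entry[-1] is the rollout step at which this slot was produced
            if n_behind == 0 or entry[-1] <= step - n_behind:
                mapping[len(mapping)] = buffer_idx
        return mapping

    # Hypothetical buffer entries of the form (data, produced_at_step)
    buffer = [("grp0", 0), ("grp1", 1), (None, None), ("grp2", 2)]
    print(effective_mapping(buffer, step=2, n_behind=0))  # {0: 0, 1: 1, 2: 3}
    print(effective_mapping(buffer, step=2, n_behind=1))  # {0: 0, 1: 1}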
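
The rl_example.py hunk forces fixed-length generations when --enable_profiling is set, so profiling runs always decode the full max_new_tokens budget instead of stopping at EOS or a stop string. A hedged sketch of what the two branches produce; build_generate_overrides is a hypothetical helper, while the keys are standard vLLM SamplingParams fields.

    # Distills the two branches of the diff into one helper (illustrative only)
    def build_generate_overrides(enable_profiling, reward_type, max_new_tokens):
        if enable_profiling:
            # Profiling measures fixed-length decoding: EOS and stop strings disabled
            return dict(
                max_tokens=max_new_tokens,
                ignore_eos=True,
                include_stop_str_in_output=True,
                stop=None,
            )
        return dict(
            max_tokens=max_new_tokens,
            ignore_eos=(reward_type == "think_answer_tags"),
            include_stop_str_in_output=True,
            stop=["</answer>"] if reward_type == "think_answer_tags" else None,
        )

    print(build_generate_overrides(True, "think_answer_tags", 512)["stop"])   # None
    print(build_generate_overrides(False, "think_answer_tags", 512)["stop"])  # ['</answer>']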