Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-26 21:10:30 +00:00
[feat] Update requirements and set return logits False
This commit is contained in: parent 10e52012e4, commit f9abaa8784
@@ -360,7 +360,7 @@ class GRPOConsumer(BaseConsumer):
                     criterion=_criterion,
                     optimizer=self.optimizer,
                     return_loss=True,
-                    return_outputs=True,
+                    return_outputs=False,
                 )
                 loss = policy_model_outputs["loss"]

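With return_outputs=False, the pipeline execution call (presumably ColossalAI's booster.execute_pipeline, given the keyword arguments shown) hands back only what the consumer actually reads, the loss, instead of also returning the model outputs with their logits. The snippet below is a toy sketch of that trade-off, not the ColossalAI API; run_step and its arguments are illustrative stand-ins.

import torch

def run_step(model, batch, criterion, return_outputs: bool):
    logits = model(batch)              # full activations, potentially [batch, seq, vocab]
    loss = criterion(logits)
    out = {"loss": loss}
    if return_outputs:
        out["outputs"] = logits        # keeping this tensor alive costs memory the GRPO consumer never uses
    return out

model = torch.nn.Linear(16, 32)
batch = torch.randn(4, 16)
result = run_step(model, batch, lambda y: y.pow(2).mean(), return_outputs=False)
result["loss"].backward()              # only the scalar loss is consumed downstream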
@@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
     tp_size = plugin_config.get("tp_size", 1)
     pp_size = plugin_config.get("pp_size", 1)
     ep_size = plugin_config.get("ep_size", 1)
-    sp_size = plugin_config.get("sp_size", 1)
-    return n_procs // (tp_size * pp_size * ep_size * sp_size)
+    return n_procs // (tp_size * pp_size * ep_size)


 def launch_distributed(
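After this change get_dp_size_fast no longer divides by sp_size when deriving the data-parallel degree, presumably because the "split_gather" sequence-parallelism mode enabled later in this commit runs inside the tensor-parallel group rather than on a separate rank axis (my reading of the change; the commit message does not say). A small worked example of the resulting helper:

from typing import Any, Dict

def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int:
    tp_size = plugin_config.get("tp_size", 1)
    pp_size = plugin_config.get("pp_size", 1)
    ep_size = plugin_config.get("ep_size", 1)
    return n_procs // (tp_size * pp_size * ep_size)

# 16 trainer processes with tp_size=4 and pp_size=2 give 16 // (4 * 2 * 1) = 2 DP replicas;
# an sp_size entry in the plugin config no longer shrinks this further.
print(get_dp_size_fast(16, {"tp_size": 4, "pp_size": 2, "sp_size": 4}))  # -> 2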
@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0
@@ -213,7 +213,7 @@ if __name__ == "__main__":
         )
         generate_config.update(
             dict(
-                max_tokens=args.max_new_tokens, # max new tokens
+                max_tokens=args.max_new_tokens + args.max_prompt_tokens, # max new tokens
                 ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
                 stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
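For illustration, this is roughly what the updated generation config evaluates to after the change; the numeric values are placeholder stand-ins for the CLI arguments, not values taken from the script.

max_new_tokens = 512        # placeholder for args.max_new_tokens
max_prompt_tokens = 1024    # placeholder for args.max_prompt_tokens
reward_type = "think_answer_tags"

generate_config = {}
generate_config.update(
    dict(
        max_tokens=max_new_tokens + max_prompt_tokens,  # the budget now also covers the prompt allowance
        ignore_eos=True if reward_type == "think_answer_tags" else False,
        include_stop_str_in_output=True,
        stop=["</answer>"] if reward_type == "think_answer_tags" else None,
    )
)
print(generate_config["max_tokens"])  # -> 1536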
@@ -304,6 +304,10 @@ if __name__ == "__main__":
             ), # microbatch size should be set to train_microbatch_size // pp_size
             "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather", # ["split_gather", "ring", "all_to_all"]
         }, # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
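The four added keys switch the training plugin onto the fused-attention plus sequence-parallel path. The sketch below restates them as a standalone dict with placeholder values for the argparse-driven fields; the surrounding keys and the exact plugin constructor are as in the script above, not reproduced here.

tensor_parallel_size = 2    # placeholder for args.tensor_parallel_size
zero_stage = 1              # placeholder for args.zero_stage

plugin_config = {
    "zero_stage": zero_stage,
    "max_norm": 1.0,
    "enable_flash_attention": True,               # use the fused attention path in the sharded model
    "sp_size": tensor_parallel_size,              # sequence-parallel degree tied to the TP degree
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "split_gather",  # other modes: "ring", "all_to_all"
}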
@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
             else:
                 position_ids = position_ids.view(-1, seq_length).long()

-            if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+            if (
+                not shard_config.enable_flash_attention
+                and attention_mask is not None
+                and self._attn_implementation == "flash_attention_2"
+                and use_cache
+            ):
                 is_padding_right = attention_mask[:, -1].sum().item() != batch_size
                 if is_padding_right:
                     raise ValueError(
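The guard gains a leading not shard_config.enable_flash_attention, so the right-padding check only fires when ColossalAI's own flash-attention handling is off. A toy reconstruction of the predicate (not the actual module code) shows when the check still triggers:

import torch

def needs_padding_check(enable_flash_attention, attention_mask, attn_impl, use_cache):
    return (
        not enable_flash_attention
        and attention_mask is not None
        and attn_impl == "flash_attention_2"
        and use_cache
    )

mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # row 0 is right-padded
batch_size = mask.shape[0]

if needs_padding_check(False, mask, "flash_attention_2", True):
    is_padding_right = mask[:, -1].sum().item() != batch_size
    print(is_padding_right)  # True: the original ValueError path would trigger

print(needs_padding_check(True, mask, "flash_attention_2", True))  # False: check skipped when flash attention is shard-managed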
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

     return forward

@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = inputs_embeds

         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
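With enable_flash_attention set, the model forward no longer materializes a dense 4D mask through ColoAttention.prepare_attn_kwargs; it passes attention_mask=None and leaves causal masking to the fused attention path, while the non-flash branch keeps building the additive 4D mask. The branch below is a simplified, self-contained illustration of the two paths; the 4D-mask builder is a toy stand-in for the transformers _prepare_4d_causal_attention_mask helper, not the real one.

import torch

def select_attention_mask(enable_flash_attention: bool, attention_mask: torch.Tensor, seq_length: int):
    if enable_flash_attention:
        # fused attention handles causal masking internally; no dense mask is built
        return None
    # toy stand-in: additive causal mask plus padding mask, broadcast to [batch, 1, seq, seq]
    batch_size = attention_mask.shape[0]
    causal = torch.full((seq_length, seq_length), float("-inf")).triu(1)
    padding = torch.zeros(batch_size, 1, 1, seq_length)
    padding.masked_fill_(attention_mask[:, None, None, :] == 0, float("-inf"))
    return causal[None, None] + padding

mask = torch.ones(2, 4, dtype=torch.long)
print(select_attention_mask(True, mask, 4))         # None
print(select_attention_mask(False, mask, 4).shape)  # torch.Size([2, 1, 4, 4])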
@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0