This commit is contained in:
YeAnbang 2025-08-14 18:59:54 +08:00
parent 32b2148670
commit bbc5fb4ed8
8 changed files with 15 additions and 12 deletions

View File

@ -35,6 +35,7 @@ jobs:
- name: Install ChatGPT
run: |
pip install flash-attn --no-build-isolation
cd applications/ColossalChat
pip install --no-cache-dir -v .
pip install --no-cache-dir -r examples/requirements.txt

View File

@ -31,6 +31,7 @@ jobs:
- name: Install ChatGPT
run: |
pip install flash-attn --no-build-isolation
cd applications/ColossalChat
pip install -v .
pip install pytest

View File

@ -117,6 +117,9 @@ class NaiveExperienceMaker(ExperienceMaker):
f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}"
)
generate_kwargs["stop_token_ids"] = stop_token_ids
# Hack: manually initialize cache_position to address transformer version conflict
if generate_kwargs.get("cache_position", None) is None and generate_kwargs.get("use_cache", False) is True:
generate_kwargs["cache_position"] = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
torch.manual_seed(41) # for tp, gurantee the same input for reward model
if self.use_grpo and self.num_generation > 1:

View File

@ -193,12 +193,12 @@ class KTOTrainer(SLTrainer):
loss_mean = all_reduce_mean(tensor=loss)
chosen_reward_mean = chosen_rewards.mean()
chosen_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
torch.tensor(0, dtype=chosen_reward_mean.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(chosen_rewards_list, chosen_reward_mean)
rejected_reward_mean = rejected_rewards.mean()
rejected_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
torch.tensor(0, dtype=rejected_reward_mean.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(rejected_rewards_list, rejected_reward_mean)
chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]

View File

@ -69,14 +69,12 @@ def train(args):
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
if args.rm_pretrain:
@ -88,11 +86,11 @@ def train(args):
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)
if args.rm_pretrain:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
args.pretrain, trust_remote_code=True
)
if args.lora_config is not None:

View File

@ -78,14 +78,12 @@ def train(args):
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
if not args.no_neural_reward_model:
@ -103,9 +101,9 @@ def train(args):
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
args.pretrain, trust_remote_code=True
)
if not args.no_neural_reward_model:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)

View File

@ -7,7 +7,8 @@ DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
CONFIG_DIR=$BASE_DIR/conversation_template
# MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
# MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi") # chatGLM2 cannot pass with transformers=4.40 above
MODELS=("colossal-llama2" "llama2" "chatGLM3" "deepseek" "Yi")
get_pretrain() {
local model=$1

View File

@ -40,7 +40,8 @@ export OMP_NUM_THREADS=8
get_pretrain() {
local model=$1
if [[ $model == "llama" ]]; then
echo "nickypro/tinyllama-110M"
# echo "nickypro/tinyllama-15M"
echo "TinyPixel/llama-110m"
elif [[ $model == "opt" ]]; then
echo "facebook/opt-125m"
else