Compare commits

...

10 Commits
v0.4.8 ... main

Author SHA1 Message Date
flybird11111
46ed5d856b
[ci] update ci (#6254)
* fix for async io

* test for upgrading transformers

* add ci machine

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_fp16_torch.py

* Update build_on_pr.yml

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 16:40:53 +08:00
Yanjia0
7ecdf9a211
Update README.md (#6268)
Change image from H100 to H200
2025-04-17 12:07:25 +08:00
duanjunwen
44d4053fec
[HotFix] update load lora model Readme; (#6240)
* [fix] update load lora model Readme;

* [fix] update lora infer readme

* [fix] remove useless comments
2025-03-07 14:14:26 +08:00
Hongxin Liu
6d676ee0e9
[release] update version (#6236) 2025-03-03 16:15:09 +08:00
Hongxin Liu
56fe130b15
[hotfix] fix lora load (#6231)
* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
2025-03-01 19:04:14 +08:00
Hongxin Liu
f32861ccc5
[misc] update torch version (#6206)
* [misc] update torch version

* fix test

* fix test

* fix test

* fix test
2025-02-24 14:35:48 +08:00
YeAnbang
b9e60559b8
Merge pull request #6208 from hpcaitech/grpo_dev
[Chat] fix colossalchat bugs
2025-02-20 21:23:16 +08:00
pre-commit-ci[bot]
7595c453a5 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-02-20 10:25:19 +00:00
YeAnbang
53834b74b9 fix num_train_step update 2025-02-20 18:24:04 +08:00
YeAnbang
0171884664 fix inference rebatching bug 2025-02-20 17:28:49 +08:00
35 changed files with 242 additions and 67 deletions

View File

@ -1,3 +1,3 @@
2.2.2-12.1.0
2.3.0-12.1.0
2.4.0-12.4.1
2.5.1-12.4.1

View File

@ -1,11 +1,11 @@
{
"build": [
{
"torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
"torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
"cuda_image": "hpcaitech/cuda-conda:12.1"
},
{
"torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
"torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
"cuda_image": "hpcaitech/cuda-conda:12.4"
}
]

View File

@ -87,10 +87,10 @@ jobs:
name: Build and Test Colossal-AI
needs: detect
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
runs-on: [self-hosted, gpu]
runs-on: ubuntu-latest
container:
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:

View File

@ -38,7 +38,7 @@ Limited Academic Bonuses:
<div align="center">
<a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2.gif" width="850" />
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2-2.gif" width="850" />
</a>
</div>

View File

@ -140,7 +140,7 @@ class NaiveExperienceMaker(ExperienceMaker):
num_actions = 0
for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
s, e = inference_mini_batch_id, (inference_mini_batch_id + 1) * self.inference_batch_size
s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
if input_ids[s:e].size(0) == 0:
break
sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
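The slicing fix above is easiest to see with concrete numbers: the loop variable already advances in steps of `inference_batch_size`, so the old end index `(inference_mini_batch_id + 1) * self.inference_batch_size` overshoots every mini-batch after the first. A standalone sketch (the batch sizes below are made-up values, not trainer attributes):

```python
# Standalone sketch of the rebatching fix; inference_batch_size = 4 and a
# 10-prompt batch are illustrative values only.
inference_batch_size = 4
num_prompts = 10

for i in range(0, num_prompts, inference_batch_size):   # i = 0, 4, 8
    old_end = (i + 1) * inference_batch_size             # 4, 20, 36 -- overshoots
    new_end = i + inference_batch_size                    # 4, 8, 12  -- correct window
    print(f"mini-batch start={i}: old slice [{i}:{old_end}], fixed slice [{i}:{new_end}]")
```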

View File

@ -380,8 +380,8 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.get("accuracy"),
global_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
self.num_train_step += 1
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint

View File

@ -231,7 +231,6 @@ class GRPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
num_actions = experience.action_log_probs.size(1)
# policy loss
@ -294,7 +293,7 @@ class GRPOTrainer(OLTrainer):
self.temperature_annealing_scheduler.step_forward()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@ -327,6 +326,7 @@ class GRPOTrainer(OLTrainer):
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
self.accumulative_meter.reset()
self.num_train_step += 1
def _learn(self, update_step: int):
"""

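For the GRPO change above (and the matching PPO change later in this diff), moving `self.num_train_step += 1` from the top of `_training_step` to after the logging block makes the counter equal to the number of completed steps, so the sample-decoding branch switches from `% 10 == 1` to `% 10 == 0` and still fires on the 1st, 11th, 21st, ... call. A tiny standalone counting sketch:

```python
# Counting sketch: the counter is now incremented *after* logging, so
# `step % 10 == 0` triggers on the 1st, 11th, 21st, ... training call,
# matching the old `% 10 == 1` check that ran on a pre-incremented counter.
num_train_step = 0
for call in range(1, 25):
    if num_train_step % 10 == 0:
        print(f"call {call}: decode and log sampled responses (step counter = {num_train_step})")
    num_train_step += 1
```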
View File

@ -256,7 +256,7 @@ class KTOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1
step_bar.close()

View File

@ -233,7 +233,7 @@ class ORPOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1
step_bar.close()

View File

@ -220,7 +220,6 @@ class PPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
self.critic.train()
num_actions = experience.action_log_probs.size(1)
@ -294,7 +293,7 @@ class PPOTrainer(OLTrainer):
self.critic_scheduler.step()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@ -336,6 +335,7 @@ class PPOTrainer(OLTrainer):
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
self.accumulative_meter.reset()
self.num_train_step += 1
def _learn(self, update_step: int):
"""

View File

@ -193,7 +193,7 @@ class RewardModelTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
self.num_train_step += 1
self.num_train_step += 1
step_bar.close()
def _eval(self, epoch):

View File

@ -152,9 +152,9 @@ class SFTTrainer(SLTrainer):
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
self.num_train_step += 1
# Save checkpoint
if (

View File

@ -892,6 +892,63 @@ The dialogues can by multiple turns and it can contain system prompt. For more d
We use bf16 weights for finetuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use the [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert the weights to bf16 via GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).
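In outline, these conversion scripts dequantize each fp8 tensor with its stored inverse scale and re-save the result in bf16. The sketch below only illustrates that idea under the simplifying assumption of a single per-tensor inverse scale stored under a `<name>_scale_inv` key; the real scripts additionally handle block-wise scales, sharded safetensors index files, and the exact key naming, so use the linked tools for actual conversion.

```python
# Illustration only (not the linked scripts): dequantize fp8 tensors to bf16,
# assuming one per-tensor inverse scale saved under "<name>_scale_inv".
# The real DeepSeek conversion handles block-wise scales and sharded checkpoints.
import torch
from safetensors.torch import load_file, save_file

def cast_fp8_file_to_bf16(in_path: str, out_path: str) -> None:
    tensors = load_file(in_path)
    converted = {}
    for name, tensor in tensors.items():
        if name.endswith("_scale_inv"):  # scale tensors are consumed, not re-saved
            continue
        scale_inv = tensors.get(f"{name}_scale_inv")
        if tensor.dtype == torch.float8_e4m3fn and scale_inv is not None:
            converted[name] = (tensor.to(torch.float32) * scale_inv.to(torch.float32)).to(torch.bfloat16)
        else:
            converted[name] = tensor
    save_file(converted, out_path)
```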
We have also added details on how to load and run inference with LoRA models.
```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from peft import (
    PeftModel
)
import torch

# Set model path
model_name = "Qwen/Qwen2.5-3B"
lora_adapter = "Qwen2.5-3B_lora"  # Your lora model path
merged_model_path = "Qwen2.5-3B_merged"

######
# How to Load lora Model
######
# 1. Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# 2. Load lora model
peft_model = PeftModel.from_pretrained(
    base_model,
    lora_adapter,
    torch_dtype=torch.bfloat16
)

# 3. Merge lora model
merged_model = peft_model.merge_and_unload()

# 4. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    pad_token="<|endoftext|>"
)

# 5. Save merged lora model
merged_model.save_pretrained(
    merged_model_path,
    safe_serialization=True
)
tokenizer.save_pretrained(merged_model_path)

# 6. Run inference
test_input = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to("cuda")
output = merged_model.generate(**test_input, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
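Merging is convenient for deployment, but it is optional for a quick sanity check of the adapter: generation can also be run directly through the `PeftModel` from step 2, leaving the base weights untouched. A short sketch, assuming the same `peft_model` and `tokenizer` as above and skipping the merge (steps 3 and 5):

```python
# Optional alternative to merging: generate through the PeftModel itself
# (run this instead of steps 3 and 5; reuses peft_model and tokenizer from above).
peft_model.eval()
test_input = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to("cuda")
with torch.no_grad():
    output = peft_model.generate(**test_input, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```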
#### Usage
After preparing the dataset and model weights, you can run the script with the following command:

View File

@ -257,7 +257,7 @@ def train(args) -> None:
)
torch.set_default_dtype(torch.float)
booster.load_model(model, args.pretrained)
booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)
coordinator.print_on_master(
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"

View File

@ -85,11 +85,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for k, v in state_dict.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
state_dict[k] = self.pinned_state_dicts[id(model)][k]
self.pinned_state_dicts[hash(model)][k].copy_(v)
state_dict[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, state_dict)
self.async_writers.append(writer)
else:
@ -172,9 +172,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
if use_async and self.coordinator.is_master():
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = model.state_dict_shard(

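The `id(model)` → `hash(model)` switch in this and the following checkpoint-IO diffs appears to pair with the `PeftUnwrapMixin.__hash__` added further down: `unwrap()` now returns a fresh wrapper object on each call, so an `id()`-keyed cache of pinned state dicts would never hit, while a hash that delegates to the underlying module stays stable. A standalone sketch of that caching behaviour with a hypothetical wrapper (not ColossalAI code):

```python
# Hypothetical Wrapper: keying a cache by hash() instead of id() survives
# re-wrapping, because __hash__ delegates to the wrapped object while every
# wrapper instance has a different id().
class Wrapper:
    def __init__(self, module):
        self.module = module

    def __hash__(self):
        return hash(self.module)

class Module:  # stands in for an nn.Module
    pass

m = Module()
pinned_cache = {hash(Wrapper(m)): "pinned buffers"}

w1, w2 = Wrapper(m), Wrapper(m)
print(id(w1) == id(w2))          # False: a fresh wrapper on every unwrap()
print(hash(w1) == hash(w2))      # True: both delegate to the same module
print(hash(w2) in pinned_cache)  # True: the pinned buffers are reused
```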
View File

@ -26,6 +26,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
@ -225,7 +226,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
if isinstance(model, DDP):
model = model.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
model = PeftUnwrapMixin(model)
return model
def _force_wait_all_gather(self):

View File

@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
model = self.module.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
model = PeftUnwrapMixin(model)
return model

View File

@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[id(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
self.pinned_state_dicts[hash(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
state_dict = model.unwrap().state_dict()
if use_async and self.coordinator.is_master():
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = utils.shard_model_checkpoint(

View File

@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async:
from colossalai.utils.safetensors import move_and_save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
index_file = CheckpointIndexFile(checkpoint_path)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
is_master=True,
pinned_state_dict=pinned_state_dict,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.

View File

@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Only devices with tp_rank == 0 are responsible for model saving.
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
self.pinned_state_dicts[hash(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
self.pinned_state_dicts[hash(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:

View File

@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
dist.gather(param, all_param, group=ep_group)
dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.gather_object(state_dict, out, group=self.pp_group)
if self.pp_rank == 0:
out = [None for _ in range(self.pp_size)]
else:
out = None
dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:

View File
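The MoE fix above reflects how `torch.distributed.gather` and `gather_object` interpret `dst`: it is a global rank, so gathering to the first rank of a sub-group needs `dist.get_global_rank(group, 0)` rather than the group-local `0`, and only that destination rank should allocate the output list. A minimal standalone sketch, assuming a 4-process `torchrun` launch and the gloo backend (the file name and group layout are illustrative, unrelated to the MoE code itself):

```python
# Minimal sketch, e.g. `torchrun --nproc_per_node=4 gather_demo.py` (hypothetical name).
import torch
import torch.distributed as dist

dist.init_process_group("gloo")
rank, world = dist.get_rank(), dist.get_world_size()
group = dist.new_group(ranks=[world - 2, world - 1])  # toy sub-group of the last two ranks

if rank in (world - 2, world - 1):
    dst = dist.get_global_rank(group, 0)              # global rank of group-local rank 0
    tensor = torch.tensor([float(rank)])
    gather_list = [torch.zeros_like(tensor) for _ in range(2)] if rank == dst else None
    dist.gather(tensor, gather_list, dst=dst, group=group)
    if rank == dst:
        print("gathered:", gather_list)

dist.destroy_process_group()
```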

@ -20,6 +20,7 @@ from torch.optim import Optimizer
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from colossalai.accelerator import get_accelerator
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
except ImportError:
return
if isinstance(model, PeftUnwrapMixin):
model = model.base_model
if not isinstance(model, PreTrainedModel):
return
@ -692,6 +695,9 @@ def load_state_dict_into_model(
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if isinstance(model, PeftUnwrapMixin):
state_dict = model.patch_state_dict(state_dict)
model = model.base_model
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

View File

@ -1,5 +1,102 @@
import re
from typing import Dict, Set
import torch
import torch.nn as nn
from peft import PeftModel
from peft import PeftModel, PeftType
def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
config = model.peft_config[adapter_name]
if config.peft_type != PeftType.LORA:
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none":
to_return = {k for k in names if "lora_" in k}
elif bias == "all":
to_return = {k for k in names if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = set()
for k in names:
if "lora_" in k:
to_return.add(k)
bias_name = k.split("lora_")[0] + "bias"
if bias_name in names:
to_return.add(bias_name)
else:
raise NotImplementedError
to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.use_dora:
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
# we want the state_dict format not to change, we remove the "weight" part.
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
def renamed_dora_weights(k):
if k.endswith(new_dora_suffix):
k = k[:-7] # remove ".weight"
return k
to_return = {renamed_dora_weights(k) for k in to_return}
to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
return to_return
class PeftUnwrapMixin:
def __init__(self, peft_model: PeftModel):
self.base_model = peft_model.get_base_model()
# peft does not affect buffers
self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
potential_lora_weights = set()
for n in self.lora_layers:
potential_lora_weights.add(f"{n}.weight")
potential_lora_weights.add(f"{n}.bias")
self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
def named_parameters(self):
for n, p in self.base_model.named_parameters():
if n in self.lora_param_to_origin_param:
n = self.lora_param_to_origin_param[n]
yield n, p
def named_buffers(self):
return self.base_model.named_buffers()
@property
def _modules(self):
return self.base_model._modules
@property
def _non_persistent_buffers_set(self):
return self.base_model._non_persistent_buffers_set
def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
new_state_dict = {}
for k, v in state_dict.items():
if k in self.origin_param_to_lora_param:
k = self.origin_param_to_lora_param[k]
new_state_dict[k] = v
return new_state_dict
def state_dict(self):
state_dict = {}
for k, v in self.base_model.state_dict().items():
if k in self.lora_param_to_origin_param:
k = self.lora_param_to_origin_param[k]
state_dict[k] = v
return state_dict
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
state_dict = self.patch_state_dict(state_dict)
self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
def __hash__(self):
return hash(self.base_model)
class ModelWrapper(nn.Module):
@ -23,7 +120,7 @@ class ModelWrapper(nn.Module):
else:
model = self.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
model = PeftUnwrapMixin(model)
return model
def forward(self, *args, **kwargs):

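To make the effect of `PeftUnwrapMixin` concrete, here is a small sketch (assuming a recent `peft` with LoRA/DoRA support and the `colossalai.interface.model` import path shown in this diff): the wrapper exposes LoRA-wrapped parameters under their original base-model names, so checkpoints keep the plain naming while the raw `lora_A`/`lora_B` parameters remain visible.

```python
# Sketch only: show how PeftUnwrapMixin remaps `*.base_layer.weight` back to the
# base model's parameter names. The toy module and adapter config are illustrative.
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from colossalai.interface.model import PeftUnwrapMixin

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

peft_model = get_peft_model(Toy(), LoraConfig(target_modules=["fc"]))
unwrapped = PeftUnwrapMixin(peft_model)

print(sorted(name for name, _ in unwrapped.named_parameters()))
# Expected (illustrative): 'fc.weight' and 'fc.bias' instead of
# 'fc.base_layer.weight' / 'fc.base_layer.bias', plus 'fc.lora_A.default.weight' etc.
```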
View File

@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
torch>=2.2.0,<=2.4.1
torch>=2.2.0,<=2.5.1
safetensors
einops
pydantic

View File

@ -1,7 +1,7 @@
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import spawn
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_device_mesh_manager(rank, world_size, port):
@ -24,6 +24,7 @@ def check_device_mesh_manager(rank, world_size, port):
assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]
@rerun_if_address_is_in_use()
def test_device_mesh_manager():
spawn(check_device_mesh_manager, 4)

View File

@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@clear_cache_before_run()
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])
@ -24,6 +25,7 @@ def check_all2all(shape, dtype, async_op):
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
@clear_cache_before_run()
@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("async_op", [True, False])

View File

@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import _all_to_all_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@clear_cache_before_run()
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])

View File

@ -6,11 +6,12 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
dist.all_to_all_single
@clear_cache_before_run()
@parameterize("shape", [(4), (8, 7), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
@parameterize("fp8_format", ["e4m3", "e5m2"])

View File

@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import _all_gather_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@clear_cache_before_run()
@parameterize(
"shape",
[(3, 7, 16)],

View File

@ -5,7 +5,7 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
@ -20,6 +20,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
(8,),
],
)
@clear_cache_before_run()
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("fp8_format", ["e4m3", "e5m2"])
@parameterize("async_op", [True, False])

View File

@ -3,9 +3,10 @@ from torch.testing import assert_close
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
@clear_cache_before_run()
@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parameterize("fp8_format", ["e4m3", "e5m2"])

View File

@ -8,7 +8,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing import assert_close
from colossalai import launch
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
@ -28,6 +28,7 @@ class ToyModel(nn.Module):
return self.net2(self.relu(self.net1(x)))
@clear_cache_before_run()
@parameterize("mode", ["grad", "params"])
def run_model(mode):
rank = dist.get_rank()

View File

@ -6,9 +6,10 @@ from torch.testing import assert_close
from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import reduce_scatter_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@clear_cache_before_run()
@parameterize("shape", [(16, 8, 4)])
@parameterize("scatter_dim", [0, 1, 2])
@parameterize("dtype", [torch.bfloat16, torch.float16])

View File

@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-2, 5e-2
atol, rtol = 9e-2, 0
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
row_layer_grads = get_grad_tensors_for_check(
t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0

View File

@ -1 +1 @@
0.4.8
0.4.9