mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Feature] Support LLaMA-3 CPT and ST (#5619)
* support LLaMA-3 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Run pre-commit --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
18
applications/Colossal-LLaMA/colossal_llama/utils/froze.py
Normal file
18
applications/Colossal-LLaMA/colossal_llama/utils/froze.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from transformers.models.llama import LlamaForCausalLM
|
||||
|
||||
|
||||
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze every parameter except the embedding layers.

    Parameters whose qualified name contains ``embed_tokens`` (input
    embeddings) or ``lm_head`` (output projection / tied embeddings) are
    kept trainable; all others get ``requires_grad = False``.
    """
    trainable_markers = ("embed_tokens", "lm_head")
    for param_name, param in model.named_parameters():
        # Trainable iff the name mentions one of the embedding modules.
        param.requires_grad = any(marker in param_name for marker in trainable_markers)
|
||||
|
||||
|
||||
def unfreeze_parameters(model: "LlamaForCausalLM") -> None:
    """Make every parameter of *model* trainable again.

    BUG FIX: the previous body set ``requires_grad = False``, which froze
    all parameters — the exact opposite of what the function name (and its
    pairing with ``freeze_non_embeds_parameters``) promises.  It now sets
    ``requires_grad = True`` on every parameter.

    The annotation is a string (lazy) so the function can be defined even
    when ``transformers`` is imported lazily; runtime behavior is unchanged.
    """
    # The parameter name is not needed here; keep named_parameters() for
    # symmetry with freeze_non_embeds_parameters, discarding the name.
    for _name, param in model.named_parameters():
        param.requires_grad = True
|
Reference in New Issue
Block a user