ColossalAI/applications/Colossal-LLaMA/colossal_llama/utils/froze.py
Tong Li 862fbaaa62
[Feature] Support LLaMA-3 CPT and ST (#5619)
* support LLaMA-3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Run pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-23 13:54:05 +08:00

19 lines
580 B
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Leave only the embedding-related weights of ``model`` trainable.

    Every parameter yielded by ``model.named_parameters()`` has its
    ``requires_grad`` flag overwritten: it becomes ``True`` when the
    parameter name contains ``"embed_tokens"`` or ``"lm_head"`` and
    ``False`` otherwise. The model is modified in place.

    Args:
        model: Model whose parameters are frozen or unfrozen by name.
    """
    trainable_markers = ("embed_tokens", "lm_head")
    for param_name, tensor in model.named_parameters():
        # Substring match, so the markers may appear anywhere in the dotted name.
        tensor.requires_grad = any(marker in param_name for marker in trainable_markers)
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Make every parameter of ``model`` trainable again.

    Sets ``requires_grad = True`` on each parameter yielded by
    ``model.named_parameters()``. The model is modified in place.

    Args:
        model: Model whose parameters should all receive gradients.
    """
    for _, params in model.named_parameters():
        # Unfreezing means enabling gradients; the previous code set this
        # flag to False, which froze the whole model instead.
        params.requires_grad = True