Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 22:42:15 +00:00)
* support LLaMA-3
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Run pre-commit

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
19 lines | 580 B | Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from transformers.models.llama import LlamaForCausalLM
def freeze_non_embeds_parameters(model: LlamaForCausalLM) -> None:
    """Freeze all parameters except the input embeddings (`embed_tokens`) and the LM head (`lm_head`)."""
    for name, params in model.named_parameters():
        if "embed_tokens" not in name and "lm_head" not in name:
            params.requires_grad = False
        else:
            params.requires_grad = True
def unfreeze_parameters(model: LlamaForCausalLM) -> None:
    """Make every parameter of the model trainable again."""
    for name, params in model.named_parameters():
        params.requires_grad = True
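

# Minimal usage sketch for the two helpers above, assuming a LLaMA checkpoint
# is available locally or on the Hugging Face Hub; the checkpoint name below is
# an illustrative placeholder, not something this module requires.
if __name__ == "__main__":
    model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

    # Train only the embedding layers, e.g. after extending the tokenizer with
    # new special tokens, and report the resulting trainable-parameter budget.
    freeze_non_embeds_parameters(model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable} / {total}")

    # Make every parameter trainable again before a full fine-tuning run.
    unfreeze_parameters(model)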