LLMの一部レイヤーの重みを固定してLoRAを使わずFP16でファインチューニングする

LLMのファインチューニングを行う際、QLoRAやPEFTで行われている記事が多いが、これらの効率化技術を使わなかった時の性能を確認してみたくなった。

また、大規模データで学習済みのモデルは入力層に近いほどデータの抽象的な特徴を学習していると言われているが、少量の自前データで全部の層をファインチューニングをしてしまうと、入力層に近い所の重みが更新されることで特徴抽出が上手くいかなくなる可能性がある。

さらに、全層ファインチューニングを行うには要求されるマシンスペックが高くなってしまうという事情もある。

そこで、一部の層だけをファインチューニングするという、CNNやBERTでよく使われる手法をLLMでも出来ないか調べてみた。

今回はelyza/ELYZA-japanese-Llama-2-7bで試す。

elyza/ELYZA-japanese-Llama-2-7b · Hugging Face
We’re on a journey to advance and democratize artificial intelligence through open source and open science.

まずはtorchinfoでモデルの構造を確認し、どのレイヤーの重みを固定し、どのレイヤーを学習させるのか決める。

GitHub - TylerYep/torchinfo: View model summaries in PyTorch!
View model summaries in PyTorch! Contribute to TylerYep/torchinfo development by creating an account on GitHub.
pip install torchinfo
from transformers import AutoModelForCausalLM
from torchinfo import summary      


model_name = "elyza/ELYZA-japanese-Llama-2-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)

batch_size = 1
max_len = 8
input_size = (batch_size, max_len)

summary(model, input_size=input_size, dtypes=[torch.int])

torch_dtypeはAmpere以降のGPUならtorch.bfloat16、そうでないならtorch.float16にする。

複数GPUにまたがったモデルはtorchinfoで分析できないので、device_map=”auto”は書かないようにする。

モデルの構造を分析する際にダミーのテンソルが入力されるので、input_sizeは小さいほど実行時間が早くなる。

LLMの入力はintのトークンなので、summaryの引数dtypesは指定する。

指定しないと以下のエラーが出る。

実行するとモデル構造が出力される。

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
LlamaForCausalLM                                        [1, 32, 8, 128]           --
├─LlamaModel: 1-1                                       [1, 32, 8, 128]           --
│    └─Embedding: 2-1                                   [1, 8, 4096]              131,072,000
│    └─ModuleList: 2-2                                  --                        --
│    │    └─LlamaDecoderLayer: 3-1                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-2                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-3                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-4                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-5                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-6                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-7                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-8                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-9                      [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-10                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-11                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-12                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-13                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-14                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-15                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-16                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-17                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-18                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-19                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-20                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-21                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-22                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-23                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-24                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-25                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-26                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-27                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-28                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-29                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-30                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-31                     [1, 8, 4096]              202,383,360
│    │    └─LlamaDecoderLayer: 3-32                     [1, 8, 4096]              202,383,360
│    └─LlamaRMSNorm: 2-3                                [1, 8, 4096]              4,096
├─Linear: 1-2                                           [1, 8, 32000]             131,072,000
=========================================================================================================
Total params: 6,738,415,616
Trainable params: 6,738,415,616
Non-trainable params: 0
Total mult-adds (G): 6.74
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 106.38
Params size (MB): 26953.66
Estimated Total Size (MB): 27060.04
=========================================================================================================

Trainable params: 6,738,415,616で、このままだと7Bモデルの全層ファインチューニングになってしまう。

今回は、出力層に近いDecoder 3-32、RMSNorm 2-3、Linear 1-2だけを学習させることにする。

これ以外のレイヤーを固定する方法を調べため、ヒントを求めて画像分類のチュートリアルを見てみる。

Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 2.2.1+cu121 documentation

モデルに含まれるParameterクラスのrequires_grad属性をFalseにすればいいみたい。

Parameter — PyTorch 2.2 documentation

モデルはModuleオブジェクトであり、named_parametersメソッドを使うことで名前とパラメータクラスを返すイテレータにアクセスできる。

Module — PyTorch 2.2 documentation

named_parameters(prefix=”, recurse=True, remove_duplicate=True)
Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
prefix (str) – prefix to prepend to all parameter names.

recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.

Yields:
(str, Parameter) – Tuple containing the name and parameter

Return type:
Iterator[Tuple[str, Parameter]]

確認してみる。

for name, param in model.named_parameters():
    print(name, param.requires_grad)
model.embed_tokens.weight True
model.layers.0.self_attn.q_proj.weight True
model.layers.0.self_attn.k_proj.weight True
model.layers.0.self_attn.v_proj.weight True
model.layers.0.self_attn.o_proj.weight True
         (中略)
model.layers.30.input_layernorm.weight True
model.layers.30.post_attention_layernorm.weight True
model.layers.31.self_attn.q_proj.weight True
model.layers.31.self_attn.k_proj.weight True
model.layers.31.self_attn.v_proj.weight True
model.layers.31.self_attn.o_proj.weight True
model.layers.31.mlp.gate_proj.weight True
model.layers.31.mlp.up_proj.weight True
model.layers.31.mlp.down_proj.weight True
model.layers.31.input_layernorm.weight True
model.layers.31.post_attention_layernorm.weight True
model.norm.weight True
lm_head.weight True

今回は、Decoder 3-32、RMSNorm 2-3、Linear 1-2だけを学習させることにしたので、出力された名前をコピペしてリストにする。

torchinfoは1-indexなのに対し、named_parametersは0-indexなので通し番号が1つずれるのに注意する。

他のパラメータは学習させず固定したままにするので、requires_grad = Falseにする。

training_layers = [
    "model.layers.31.self_attn.q_proj.weight",
    "model.layers.31.self_attn.k_proj.weight",
    "model.layers.31.self_attn.v_proj.weight",
    "model.layers.31.self_attn.o_proj.weight",
    "model.layers.31.mlp.gate_proj.weight",
    "model.layers.31.mlp.up_proj.weight",
    "model.layers.31.mlp.down_proj.weight",
    "model.layers.31.input_layernorm.weight",
    "model.layers.31.post_attention_layernorm.weight",
    "model.norm.weight",
    "lm_head.weight",
]

for name, param in model.named_parameters():
    if name not in training_layers:
        param.requires_grad = False

もう一度torchinfoで確認。

summary(model, input_size=input_size, dtypes=[torch.int])
=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
LlamaForCausalLM                                        [1, 32, 8, 128]           --

(中略)

│    │    └─LlamaDecoderLayer: 3-30                     [1, 8, 4096]              (202,383,360)
│    │    └─LlamaDecoderLayer: 3-31                     [1, 8, 4096]              (202,383,360)
│    │    └─LlamaDecoderLayer: 3-32                     [1, 8, 4096]              202,383,360
│    └─LlamaRMSNorm: 2-3                                [1, 8, 4096]              4,096
├─Linear: 1-2                                           [1, 8, 32000]             131,072,000
=========================================================================================================
Total params: 6,738,415,616
Trainable params: 333,459,456
Non-trainable params: 6,404,956,160
Total mult-adds (G): 6.74
=========================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 106.38
Params size (MB): 26953.66
Estimated Total Size (MB): 27060.04
=========================================================================================================

Trainable params: 333,459,456と激減しており、LoRAを使わずに一部レイヤーだけをファインチューニング可能なモデルの準備ができた。

後は通法に従って、TrainingArgumentsを定義してTrainerまたはSFTTrainerを実行すれば完了。

実行するGPUに応じてTrainingArgumentsのbf16またはfp16をTrueにする。

training_argument = TrainingArguments(
    (中略)
    bf16=True,
    # fp16 = True,
)

trainer = SFTTrainer(
    model=model,
    (中略)
)

model.config.use_cache = False
trainer.train()

trainer.save_model(save_dir)

大規模言語モデル入門

ChatGPT/LangChainによるチャットシステム構築[実践]入門

コメント

タイトルとURLをコピーしました