0%

PrefixTuning 与 GQA 兼容性问题修复


问题


在使用 PEFT prefix tuning 模型的时候发现报错:

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 

  File "/opt/conda/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 284, in forward 

    hidden_states, self_attn_weights = self.self_attn( 

                                       ^^^^^^^^^^^^^^^ 

  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl 

    return self._call_impl(*args, **kwargs) 

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 

  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl 

    return forward_call(*args, **kwargs) 

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 

  File "/opt/conda/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 223, in forward 

    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 

                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 

  File "/opt/conda/lib/python3.11/site-packages/transformers/cache_utils.py", line 545, in update 

    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2

                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 320 but got size 128 for tensor number 1 in the list.

相关设置如下:

prefix_config = PrefixTuningConfig(
task_type=TaskType.FEATURE_EXTRACTION,
num_virtual_tokens=64
inference_mode=False
)
peft_model = get_peft_model(base_hf_model, prefix_config)

排查过程

初步怀疑是因为现在的 LLM 使用 GQA 的技术从而与 PEFT 的默认实现产生冲突。Ptuning v2 通过在 attention 层引入额外的 key 和 value 以增强模型表示,从而实现微调效果。

现在我们尝试打印下模型的结构和关键参数。

Qwen3-0.6b 的结果如下所示:

model_config.num_key_value_heads: 8
model_config.head_dim: 128
model_config.num_attention_heads: 16
model_config.hidden_size:1024

(self_attn): Qwen3Attention(
(q_proj): Linear(in_features=1024, out_features=2048, bias=False) (k_proj): Linear(in_features=1024, out_features=1024, bias=False) (v_proj): Linear(in_features=1024, out_features=1024, bias=False) (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
(q_norm): Qwen3RMSNorm ( (128, ), eps=1e-06)
(k_norm) : Qwen3RMSNorm ( (128, ), eps=1e-06)
)

我们可以看到 k,v 的 注意力头有 8 个, 而 q (attention_heads) 有 16 个,q, k, v 的注意力头数并不相同。

另外 k,v 的头数与头的维度相乘似乎刚好等于hidden_size (8*128=1024)

让我们来看更大一点的模型 Qwen3-4b 的结果:

model_config.num_key_value_heads: 8
model_config.head_dim: 128
model_config.num_attention_heads: 32
model_config.hidden_size: 2560

(self_attn): Qwen3Attention(
(q_proj): Linear (in_features=2560, out_features=4096, bias=False) (k_proj): Linear(in_features=2560, out_features=1024, bias=False) (v_proj): Linear(in_features=2560, out_features=1024, bias=False) (o_proj): Linear (in_features=4096, out_features=2560, bias=False)
(q_norm): Qwen3RMSNorm ( (128, ), eps=1e-06)
(k_norm) : Qwen3RMSNorm ( (128, ), eps=1e-06)
)

我们可以看到 k,v 的 注意力头有 8 个, 而 q (attention_heads) 有 16 个,q, k, v 的注意力头数并不相同。

另外 k,v 的头数与头的维度相乘似乎刚好等于hidden_size (8* 128=1024)

让我们来看更大一点的模型 Qwen3-4b 的结果:

解决方法


只需要将 num_attention_heads 锚定 nums_kv_heads, token_dim 锚定 model_config.num_key_value_heads * model_config.head_dim 即可解决。

model_config = model.config
prefix_config = PrefixTuningConfig(
task_type=TaskType.FEATURE_EXTRACTION,
num_virtual_tokens=getattr(self.model_config, 'num_virtual_tokens', 20),
inference_mode=False,
token_dim=model_config.num_key_value_heads * model_config.head_dim,
num_layers = model_config.num_hidden_layers,
num_attention_heads = getattr(model_config, "num_key_value_heads", model_config.num_attention_heads)
)