Prism Transformer 引入了一种渐进式头调度机制,该机制在不同层之间变化注意力头的数量,早期层使用较少但更宽的头,随着深度增加单调递增。这种方法通过解决早期层与后期层不同的结构需求,在不增加架构开销的情况下挑战了标准的均匀分配。
- 早期层利用每个头的宽子空间(dh=256)来捕获丰富的局部模式,而后期层则使用许多窄头进行专门分解。
- 权重矩阵保持标准的 dmodel×dmodel 形状,使参数量保持不变。
- 总 FLOPs 在数学上对头数量不变,确保计算中立性。
- 2 的幂次方头维度(dh ∈ {256, 128})保持 Tensor Core 对齐以实现吞吐量中立。
- 结果表明,与具有相同 tokens/sec 和墙钟时间的均匀基线相比,每个规模下的验证损失更低。
- 该模型在包括 PIQA、WinoGrande、HellaSwag 和 ARC-Easy 在内的基准测试中取得了增益或持平。
逐层注意力距离分析证实了这种增益是结构性的,因为早期的 Prism 层在转向全局整合之前更多地关注局部。实现仅需在每个注意力层中进行一行更改以使头数量依赖于层。