The Prism Transformer introduces a progressive head schedule that varies the number of attention heads across layers, starting with fewer, wider heads in early layers and increasing the count monotonically with depth. This approach challenges the standard uniform allocation by addressing the distinct structural needs of early versus late layers without adding architectural overhead.
- Early layers use wide per-head subspaces (dh = 256) for rich local pattern capture, while late layers use many narrow heads (dh = 128) for specialized decomposition.
- Weight matrices retain standard dmodel×dmodel shapes, keeping parameter count neutral.
- Attention matrix-multiply FLOPs depend on head count H only through the product H·dh = dmodel, which is fixed, so they are invariant to the head schedule, ensuring compute neutrality.
- Power-of-2 head dims (dh ∈ {256, 128}) preserve Tensor Core alignment for throughput neutrality.
- At every scale, Prism reaches lower validation loss than the uniform baseline at identical tokens/sec and wall-clock time.
- The model achieves gains or parity on benchmarks including PIQA, WinoGrande, HellaSwag, and ARC-Easy.
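
The parameter and compute neutrality claims in the list above can be checked with simple arithmetic. The sketch below is illustrative rather than the authors' code: the function names, the even split of layers between the two head widths, and the example sizes (8 layers, dmodel = 1024, sequence length 2048) are all assumptions; only the head dims {256, 128} come from the text. FLOPs are counted for matrix multiplies only, at 2 FLOPs per multiply-add.

```python
def head_schedule(num_layers, d_model, head_dims=(256, 128)):
    """Return the number of heads per layer, widest heads first.

    Assumption: layers are split evenly across `head_dims`; the source
    only states that head count increases monotonically with depth.
    """
    heads = []
    for layer in range(num_layers):
        # Map depth to a stage index: early layers -> wide heads.
        stage = layer * len(head_dims) // num_layers
        d_h = head_dims[stage]
        if d_model % d_h != 0:
            raise ValueError(f"d_model={d_model} is not divisible by d_h={d_h}")
        heads.append(d_model // d_h)
    return heads


def attention_params(d_model):
    # W_Q, W_K, W_V and W_O are each d_model x d_model for any head count.
    return 4 * d_model * d_model


def attention_matmul_flops(seq_len, d_model, num_heads):
    d_h = d_model // num_heads
    projections = 4 * 2 * seq_len * d_model * d_model   # Q, K, V, O
    scores = num_heads * 2 * seq_len * seq_len * d_h    # Q K^T, per head
    mixing = num_heads * 2 * seq_len * seq_len * d_h    # weights @ V, per head
    return projections + scores + mixing


schedule = head_schedule(num_layers=8, d_model=1024)
print(schedule)  # [4, 4, 4, 4, 8, 8, 8, 8]

# Same matmul FLOPs for 4 wide heads and 8 narrow heads,
# because num_heads * d_h == d_model in both cases.
print(attention_matmul_flops(2048, 1024, 4) == attention_matmul_flops(2048, 1024, 8))  # True
```

Because `num_heads * d_h` always equals `d_model`, the head count cancels out of every matmul term. Elementwise work such as the softmax over the per-head score matrices is not counted here.
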
Per-layer attention-distance analysis confirms the gain is structural: early Prism layers attend more locally, and deeper layers then switch to global integration. Implementation requires only a one-line change per attention layer, making the head count a function of the layer index.
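
A minimal NumPy sketch of what a layer-dependent head count could look like follows. It is not the authors' implementation: `self_attention`, `mean_attention_distance`, the schedule `[2, 2, 4, 4]`, and the sizes are hypothetical, masking and normalization layers are omitted for brevity, and the distance metric shown (attention-weighted mean of |query position − key position|) is one common definition that the source does not spell out.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Multi-head self-attention; weights are d_model x d_model for any num_heads."""
    seq_len, d_model = x.shape
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_h = d_model // num_heads  # per-head width follows from the head count

    def split(t):
        # (seq_len, d_model) -> (num_heads, seq_len, d_h)
        return t.reshape(seq_len, num_heads, d_h).transpose(1, 0, 2)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_h))  # (H, n, n)
    merged = (attn @ v).transpose(1, 0, 2).reshape(seq_len, d_model)
    return merged @ w_o, attn


def mean_attention_distance(attn):
    """Attention-weighted mean |i - j| per head, for attn of shape (H, n, n)."""
    idx = np.arange(attn.shape[-1])
    dist = np.abs(idx[:, None] - idx[None, :])
    return (attn * dist).sum(axis=-1).mean(axis=-1)


rng = np.random.default_rng(0)
seq_len, d_model = 16, 512
layer_heads = [2, 2, 4, 4]  # hypothetical schedule: d_h = 256, then 128
x = rng.standard_normal((seq_len, d_model))

for layer_idx in range(len(layer_heads)):
    weights = [rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
               for _ in range(4)]
    num_heads = layer_heads[layer_idx]  # the only layer-dependent line
    out, attn = self_attention(x, *weights, num_heads=num_heads)
    x = x + out  # residual connection
    print(layer_idx, num_heads, mean_attention_distance(attn).round(2))
```

With random weights the printed distances carry no meaning; the loop only shows that the weight shapes stay fixed while the head count changes, and where a per-layer distance statistic could be collected in a trained model.
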