Every year, dozens of papers attack the same problem: self-attention scales as O(n²d), and long contexts are expensive. The standard playbook? Reduce the number of tokens that interact — local windows, kernel approximations, token-level sparsity. It works, sort of. But there's always a quality tax. The models get faster and dumber in roughly equal measure.

Paper: "Scaling Attention via Feature Sparsity" — Yan Xie, Tiansheng Wen, Tangda Huang, Bo Chen, Chenyu You, Stefanie Jegelka, Yifei Wang (Xidian University, Stony Brook, TUM, MIT, Amazon AGI SF Lab). Accepted at ICLR 2026.

Link: arXiv:2603.22300 | Code

This paper asks a deceptively simple question: what if we sparsified the other axis? Instead of choosing which tokens talk to each other, what if every token still talked to every other token — but they only compared notes on a handful of feature dimensions?

The result is Sparse Feature Attention (SFA), and the numbers suggest this overlooked axis might be the more productive one to compress.

The Core Insight: Features Are Redundant, Tokens Aren't

Here's the intuition. When a query token computes attention against a key token, standard attention computes a dot product across all d feature dimensions (typically 128 per head). But most of that information is noise for any given pair. Only a few dimensions carry the signal that determines whether these two tokens should attend to each other.

SFA makes this explicit. For each query and key vector, it keeps only the top-k largest-magnitude entries and zeros out the rest. Attention scores are then computed only where the active dimensions overlap between a query and a key.

The math is elegant. With k-sparse codes, only the overlapping non-zero coordinates of a query-key pair contribute, and two random size-k supports drawn from d dimensions overlap in about k²/d positions on average, so the cost of computing QK⊤ drops from Θ(n²d) to Θ(n²k²/d). That's a factor of (k/d)² reduction. With the paper's default settings of d=128 and k=16, that's a 64× theoretical reduction in attention compute. Scale up to d=1024 with k=32 and you're looking at 1024×.
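Those reduction factors follow directly from the (k/d)² ratio; a quick sanity check in Python (values are the paper's settings, the function name is ours):

```python
def reduction_factor(d: int, k: int) -> float:
    """Dense cost over sparse cost: (n^2 * d) / (n^2 * k^2 / d) = (d / k) ** 2."""
    return (d / k) ** 2

print(reduction_factor(d=128, k=16))   # 64.0   (paper's default head setting)
print(reduction_factor(d=1024, k=32))  # 1024.0 (higher-dimensional heads)
```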

This is orthogonal to token-level sparsity. You could combine both — reducing which tokens interact AND how many features they compare — for compounding gains.

How It Actually Works

The mechanism is refreshingly straightforward:

1. Sparse projection: Given dense Q and K matrices, apply a row-wise Top-k operator that keeps only the k largest absolute values per row, zeroing the rest. This creates k-sparse query and key representations.

2. Sparse matrix multiplication: Store Q̃ in CSR (Compressed Sparse Row) format and K̃ᵀ in CSC (Compressed Sparse Column) format. Attention scores are computed via sparse matrix multiplication — only overlapping non-zero coordinates contribute.

3. Straight-through gradients: During backpropagation, gradients flow only through the selected coordinates. Unselected dimensions get zero gradient. Simple, but it works.

4. Values stay dense: Only Q and K get sparsified. The value matrix V remains full-dimensional, preserving the richness of what gets passed forward.
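Steps 1 and 3 can be sketched in a few lines of NumPy. This is a toy illustration under our own naming, not the paper's implementation (the real kernel uses CSR/CSC storage and runs on CUDA):

```python
import numpy as np

def topk_sparsify(x: np.ndarray, k: int) -> np.ndarray:
    """Step 1: keep the k largest-magnitude entries per row, zero the rest."""
    idx = np.argpartition(-np.abs(x), k - 1, axis=-1)[..., :k]
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, np.take_along_axis(x, idx, axis=-1), axis=-1)
    return out

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((4, 8)), rng.standard_normal((4, 8))
Qs, Ks = topk_sparsify(Q, k=2), topk_sparsify(K, k=2)

# Step 2 in dense form: zeroed entries contribute nothing, so a plain matmul
# over the sparsified matrices equals the sparse intersection sum.
scores = Qs @ Ks.T / np.sqrt(Q.shape[-1])

# Step 3 (straight-through): in training, the backward pass would multiply the
# upstream gradient by the same 0/1 selection mask, e.g.:
mask = (Qs != 0).astype(Q.dtype)  # gradients flow only through selected dims
```

Note that step 4 needs no code at all: V is simply left untouched.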

The attention formula for a query-key pair becomes:

s_ij = (1/√d) × Σ_{u ∈ S_i ∩ S_j} q̃_{i,u} × k̃_{j,u}

Where S_i and S_j are the active feature supports of token i and j. If two tokens happen to activate completely different feature dimensions, their attention score is zero — they simply don't see each other through the feature lens.
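A toy check makes the formula concrete (the vectors and supports below are made up for illustration): summing over the support intersection gives exactly the same score as a dense dot product of the sparsified vectors, because non-overlapping coordinates multiply against zeros.

```python
import numpy as np

d = 8
q = np.zeros(d)
k_vec = np.zeros(d)
q[[1, 3, 6]] = [0.5, -1.2, 2.0]      # S_i = {1, 3, 6}
k_vec[[3, 4, 6]] = [0.8, 1.1, -0.4]  # S_j = {3, 4, 6}

# Sum only over the overlapping support S_i ∩ S_j = {3, 6}:
support = set(np.flatnonzero(q)) & set(np.flatnonzero(k_vec))
s_sparse = sum(q[u] * k_vec[u] for u in support) / np.sqrt(d)

# Dense dot product over the sparsified vectors gives the same score.
s_dense = (q @ k_vec) / np.sqrt(d)
assert np.isclose(s_sparse, s_dense)
```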

FlashSFA: Making It Actually Fast

Here's where many "efficient attention" papers fall apart — the theory is beautiful but the GPU implementation is slower than dense FlashAttention because of irregular memory access patterns.

The authors anticipated this and built FlashSFA, an IO-aware CUDA kernel that extends FlashAttention's tiling strategy to work with sparse feature intersections. Instead of computing dense tile multiplications, FlashSFA iterates over active features within each tile, intersects their supports, and scatter-adds into a compact score buffer. The buffer feeds directly into online softmax — no n×n score matrix is ever materialized.

This is critical. Without FlashSFA, you'd still need O(n²) memory for the attention matrix even if you computed fewer FLOPs. FlashSFA gives you both the compute AND memory savings.
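The online-softmax part, which is what avoids the n×n matrix, can be sketched in NumPy. This is the standard FlashAttention-style streaming update with dense per-tile scores, not the authors' sparse CUDA kernel; it shows why only O(n·d) state is needed:

```python
import numpy as np

def streaming_attention(Q, K, V, tile=2):
    """Attention output via online softmax over key/value tiles.

    Per-tile scores are computed and discarded; only running statistics
    (row-wise max m, normalizer l) and a partial output are kept, so
    memory stays O(n * d) rather than O(n^2).
    """
    n, d = Q.shape
    out = np.zeros_like(V)
    m = np.full(n, -np.inf)   # running row-wise max
    l = np.zeros(n)           # running softmax normalizer
    for start in range(0, n, tile):
        Kt, Vt = K[start:start + tile], V[start:start + tile]
        s = Q @ Kt.T / np.sqrt(d)          # scores for this tile only
        m_new = np.maximum(m, s.max(axis=1))
        scale = np.exp(m - m_new)          # rescale previous statistics
        p = np.exp(s - m_new[:, None])
        l = l * scale + p.sum(axis=1)
        out = out * scale[:, None] + p @ Vt
        m = m_new
    return out / l[:, None]

# Matches ordinary full-matrix softmax attention:
rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
s = Q @ K.T / 2.0
ref = np.exp(s - s.max(1, keepdims=True))
ref = (ref / ref.sum(1, keepdims=True)) @ V
assert np.allclose(streaming_attention(Q, K, V), ref)
```

In FlashSFA, the dense `Q @ Kt.T` line is where the sparse support intersection and scatter-add would go; the streaming bookkeeping around it is unchanged.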

The Numbers: Where It Gets Interesting

The paper evaluates SFA across GPT-2 (Small and Medium), Qwen3-0.6B pretraining, and fine-tuning experiments on Qwen3-0.6B and Qwen3-4B.

Pretraining quality (k=8):

  • GPT-2 Small: PPL 31.51 vs 29.85 dense — a modest 5.6% increase in perplexity
  • Qwen3-0.6B: PPL 4.81 vs 4.66 dense — only 3.2% higher
  • Average accuracy across benchmarks (PiQA, LAMBADA, ARC-e/c, HellaSwag) stays within 1-2 points of dense baselines

Compare that to the "just make embeddings smaller" baseline (short embeddings with half the hidden size):

  • Qwen3-0.6B with short embeddings: PPL jumps to 6.03 (29% worse than dense)
  • Average accuracy drops to 36.68 vs SFA's 38.94

SFA preserves expressivity because it keeps the high-dimensional space — it just activates different slices of it per token.

Efficiency gains:

  • Up to 2.5× speedup over dense attention
  • ~49% FLOP reduction
  • ~41% KV-cache memory reduction
  • At 65k context with 256-dim heads, SFA reduces latency by over 10× compared to dense

Long-context retrieval (Needle-in-a-Haystack):

  • At 32k context, SFA with k=8 achieves 82% accuracy vs dense attention's 80%
  • SFA actually outperforms dense attention on length generalization, maintaining higher accuracy at unseen lengths
  • 1.3× faster generation at 32k with k=8

Fine-tuning pretrained models:

  • On Qwen3-4B, SFA (Top-16) stays within 1-3 accuracy points of dense fine-tuning across GSM-8K, Arxiv, PubMed, and NIAH benchmarks
  • NIAH retrieval performance is essentially identical between dense and sparse fine-tuning

Why This Matters for Practitioners

If you're building or deploying LLMs, here's why you should care:

1. Context window scaling. The paper's theoretical analysis suggests SFA could extend a 1M context window to 64M at the same compute cost, or even 1B with higher-dimensional heads. That's not a typo — it's the (k/d)² scaling factor at work. Whether real-world implementations hit those numbers is another question, but the headroom is enormous.

2. KV-cache is the real bottleneck. During inference, attention compute is often less of a bottleneck than KV-cache memory. SFA stores only k values per token instead of d, directly shrinking cache by (1 - k/d). At k=8, d=128, that's 94% smaller KV-cache per head. For long-context serving, this alone could be transformative.

3. It composes with everything else. Token-level sparsity, paging, quantization — SFA is orthogonal to all of them. You could run SFA inside a sparse attention mask, with quantized KV-cache, behind a paged memory system. Each layer of optimization multiplies rather than conflicts.

4. Adaptation is viable. You don't need to pretrain from scratch. The paper shows that fine-tuning dense pretrained models into SFA works with a regularized objective that encourages sparse attention outputs to approximate dense ones. The quality gap at 4B scale is small enough to be practical.

5. The code is public. The GitHub repo includes the FlashSFA kernel, which is the hardest part to implement yourself.
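The KV-cache arithmetic in point 2 is easy to verify with the paper's k=8, d=128 setting (this counts only the k cached values per head, as the (1 - k/d) figure in the text does; index storage is ignored):

```python
d, k = 128, 8
savings = 1 - k / d          # fraction of per-head key cache eliminated
print(f"{savings:.1%}")      # 93.8%, i.e. the ~94% smaller cache cited above
```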

What Could Go Wrong

The paper is honest about limitations, and a few stand out:

GPU hardware support for sparse ops is immature. The authors note that "sparse tensor products require stronger support from GPU hardware and CUDA libraries to fully unlock their efficiency." Current GPUs are optimized for dense tensor operations. As sparse support improves (NVIDIA has been investing here), the practical gains should grow.

Very aggressive sparsity hurts. At k=2 or k=4, quality degrades noticeably. The sweet spot seems to be k=8 to k=16 for current model sizes. Adaptive sparsity budgets — where different heads or layers use different k — could help but aren't explored.

Arithmetic reasoning is sensitive. GSM-8K performance under fine-tuning shows a slightly larger gap than other tasks, suggesting that math reasoning relies on broader feature interactions. This matches intuition — precise numerical computation probably needs more feature dimensions than retrieval or comprehension.

The short-context regime isn't a win. Below 4k-8k tokens, the overhead of sparse indexing and irregular memory access means dense FlashAttention is still faster. SFA's advantages emerge at longer contexts where the quadratic cost dominates.

Our Take: This paper gets at something important that the efficient attention community has mostly missed. For years, the field has been hacking away at which tokens attend to which, while leaving the per-interaction cost untouched. SFA shows that the feature axis has far more room for compression than anyone was exploiting.

The ICLR 2026 acceptance is well-deserved. The theoretical framework is clean, the implementation is practical (not just asymptotically better), and the experiments are thorough across both pretraining and fine-tuning settings. The fact that SFA actually improves length generalization on NIAH — outperforming dense attention — is a genuinely surprising result that suggests sparse features might be a useful inductive bias, not just an approximation.

For anyone running long-context inference, the KV-cache reduction alone makes this worth investigating. For model architects, the composability with token-level sparsity opens a design space that could push practical context windows well beyond current limits.

Feature sparsity might be the most underexplored axis in transformer efficiency. This paper makes a strong case that it shouldn't stay that way.

Xie, Y., Wen, T., Huang, T., Chen, B., You, C., Jegelka, S., & Wang, Y. (2026). Scaling Attention via Feature Sparsity. ICLR 2026. arXiv:2603.22300.
