Mar 24 2025
Compressing KV cache memory by half with sparse attention
The modern large language model (LLM) landscape relies heavily on dense transformer decoders, thanks to their strong performance across diverse tasks. However, extending models to longer contexts—critical for domain-specific data and extended chat histories—creates a major KV cache memory bottleneck, spurring architectural redesigns and post-training techniques (e.g., grouped-query attention [Ainslie et al.], state-space models [Gu et al., Dong et al.], sparse attention [Jiang et al., Yang et al.], and KV cache quantization [Hooper et al.]).
In this work, we adapt a dense transformer decoder (using Llama3.1-8B-Instruct as an example) to use mostly sparse attention across its layers, applying a set of techniques, namely:
- synthetic long instruction data generation for fine-tuning,
- selective dense layers based on sensitivity to sparse attention,
- specialized sliding-window attention with global “sink” tokens,
- auxiliary “memory” tokens added to the input sequence for enhanced long-context processing.
This nearly halves KV cache memory, boosts performance on some long-context benchmarks (LongBench [Bai et al.]), and maintains comparable results on others (HELMET [Yen et al.]).
We are releasing two models – Llama-3-CBHybridL-8B (with 25 sparse attention layers out of 32) and Llama-3-CBHybridM-8B (with 28 sparse attention layers out of 32). We also release the synthetic long instruction-following data we used to fine-tune these models:
- Llama-3-CBHybridM-8B model HF page
- Llama-3-CBHybridL-8B model HF page
- Synthetic long SFT dataset HF page
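The released checkpoints can be loaded as regular Llama-style causal LMs. Below is a minimal loading sketch with Hugging Face transformers; the repository id is an assumption (check the model pages linked above for the exact id), and the prompt is just a placeholder.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical repository id; see the model pages linked above for the exact one.
repo_id = "cerebras/Llama-3-CBHybridL-8B"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "Summarize the following document:\n<long document here>"}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))
```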
Below are results for our models in terms of memory-accuracy trade-off on LongBench and HELMET evaluation suites:
Methods.
Sliding window attention with global sink tokens.
Attention sinks are a phenomenon commonly observed in pre-trained LLMs in which a few tokens (typically the very first token(s) in the sequence [Xiao et al., Han et al.]) hold a significant mass of the attention score regardless of their semantic meaning. This phenomenon has been observed in Llama3 models, where removing the first [BOS] token from the input significantly degrades model performance. Using vanilla sliding window attention on a pre-trained dense decoder model like Llama3 thus leads to a significant drop in accuracy: since the model has learned through its pre-training to offload excess attention scores to those specific sink tokens, removing them from the KV cache once the sequence length grows beyond the sliding window size has been shown empirically to destroy model performance.
A workaround proposed for this has been to always keep the attention sink tokens as part of the KV cache, effectively leading to a lambda-shaped attention mask pattern instead of a vanilla sliding window [Xiao et al., Han et al.]. The StreamingLLM work [Xiao et al.] also argues that it is necessary to update the position IDs of tokens inside the KV cache at each decoding step, which requires undoing and redoing the rotary positional embedding (RoPE) rotations of the KV tensors if the model uses RoPE. The LM-Infinite work [Han et al.] further proposed clamping relative distances that exceed the sliding window size, preventing the relative distance between the sink tokens at the start of the sequence and the sliding window from growing without bound.
Our aim was to apply this kind of sparse attention at training time. Because the recipes proposed by StreamingLLM (adjusting position IDs) and LM-Infinite (clamped relative distances) cause significant computational overhead, we have instead focused on training models with just the sparse lambda-shaped attention mask, without any adjustments made to the RoPE mechanism. We found this to work reasonably well for Llama3 models, which was also motivated by observing the performance of Llama3.1-8B-Instruct when switching from dense self-attention to sparse variants in all layers without any additional training:
By looking at the performance of sparse-attention Llama3.1-8B-Instruct on different task categories of the HELMET suite, we were able to see that (i) vanilla sliding window attention performs poorly out of the box, (ii) keeping the first 512 tokens of the sequence instead of just 8 tokens leads to a significant boost in performance in certain task categories (e.g., Recall and LongdocQA), and (iii) applying LM-Infinite-style clamping to relative distances in RoPE does not lead to significant changes in performance compared to computing RoPE with the original position IDs of the full input sequence. The latter motivated us to train the models applying only the lambda mask in the self-attention mechanism; however, we still apply the LM-Infinite-style clamping at inference, as we see slightly better performance for fine-tuned models in that case. We follow the same settings of a sliding window size of 8192 and 512 attention sink tokens (the width of the vertical stripe in the lambda mask) for all sparse layers.
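For illustration, below is a minimal sketch (not our training implementation) of how such a lambda-shaped mask can be constructed; the default window and sink sizes match the settings above.

```python
import torch

def lambda_attention_mask(seq_len: int, window: int = 8192, num_sinks: int = 512) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True means the query may attend to the key.
    Combines a causal sliding window of size `window` with `num_sinks` always-visible
    sink tokens at the start of the sequence, giving the lambda-shaped pattern."""
    q = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)
    causal = k <= q
    in_window = (q - k) < window  # keys within each query's sliding window
    is_sink = k < num_sinks       # sink tokens remain visible to every query
    return causal & (in_window | is_sink)

# Toy example: with a tiny window the characteristic lambda shape is easy to see.
print(lambda_attention_mask(8, window=3, num_sinks=2).int())
```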
One-shot sparse layer selection heuristic.
When switching from dense to sparse attention in the model, one variable that controls the accuracy vs. memory footprint trade-off is which layers (and how many) keep dense attention and which layers are converted to sparse attention. We have evaluated two strategies for balancing dense and sparse attention in the model: (a) a regular pattern, as proposed by [Rajput et al.] and others, with 1 dense layer followed by N sparse layers, and (b) selecting the M most sensitive layers to keep dense at an equivalent KV cache memory budget. We ultimately found that the latter strategy produced higher-quality models, since how much model quality degrades when a layer is converted from dense to sparse attention varies significantly from layer to layer.
To determine which specific layers are sensitive to sparse attention, we performed a simple one-shot analysis by iterating over the model's layers and setting only a single layer at a time to sparse attention while keeping the others dense. The difference in performance relative to that of the fully dense model provides a signal for how sensitive a particular layer is to switching to sparse attention. We used the HELMET RAG subset (16K context length) and defined the sensitivity score as the ratio of the accuracy score of the sparse model (with a single sparse attention layer at the specified index) to the accuracy score of the fully dense model. We ran the evaluation 3 times with different random seeds (corresponding to 3 different data subsets from the source datasets) and used the mean value of the sensitivity score to detect the sensitive layers:
Based on this one-shot analysis, we can see that layer sensitivity to sparsifying attention is highly non-uniform and does not follow a monotonic trend with model depth. Moreover, a few layers appear to be highly sensitive, whereas for certain other layers performance even improves over the baseline when that single attention layer is sparsified. Such a non-uniform pattern in layer sensitivity could be attributed to the retrieval heads phenomenon recently observed in different LLMs [Wu et al.] (i.e., layers containing more retrieval heads suffer more from limiting their attention span). It would be interesting to relate the layer-wise distribution of retrieval heads in the models to the observed sensitivity pattern.
Based on the mean sensitivity scores, we have selected the top 4 most important layers to keep dense (equivalent memory budget to a 1:7 regular pattern) and top 7 layers (equivalent memory budget to a 1:4 regular pattern). We have also trained the models with regular dense-sparse patterns (1:3, 1:7) to compare against models with selected layers.
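A sketch of this one-shot scan and the subsequent layer selection is given below; the evaluate callback is a hypothetical stand-in for running the HELMET RAG subset with the specified layers switched to sparse attention (it is not part of any public API).

```python
import statistics
from typing import Callable, Sequence

def layer_sensitivity_scores(
    evaluate: Callable[[Sequence[int], int], float],  # (sparse_layer_ids, seed) -> accuracy
    num_layers: int,
    seeds: Sequence[int] = (0, 1, 2),
) -> list[float]:
    """One-shot scan: switch a single layer at a time to sparse attention and
    compare accuracy against the fully dense run, averaging the ratio over seeds."""
    dense_acc = {s: evaluate([], s) for s in seeds}  # fully dense baseline
    scores = []
    for layer in range(num_layers):
        ratios = [evaluate([layer], s) / dense_acc[s] for s in seeds]
        scores.append(statistics.mean(ratios))  # values below 1.0 indicate a sensitive layer
    return scores

def select_dense_layers(scores: Sequence[float], num_dense: int) -> list[int]:
    """Keep dense the layers with the lowest sensitivity scores (largest accuracy drop)."""
    return sorted(range(len(scores)), key=lambda i: scores[i])[:num_dense]
```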
We can see how the different dense-sparse configurations compare by evaluating them in a zero-shot manner, without any additional training done on the model:
KV cache memory size in the above plot is estimated (in GB) based on the number of tensor elements needed to store the KV activation tensors to output the first token during the decoding phase, given an input of length 16384 (sparse attention layers only need to store 8192 + 512 KV cache elements at inference time). As seen in the plot, configurations where we convert the least sensitive layers from dense to sparse provide the best accuracy-memory trade-off in the zero-shot regime, which also further translated into accuracy gains for the trained models.
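For reference, a minimal sketch of this estimate is shown below, assuming the Llama-3.1-8B attention configuration (32 layers, 8 KV heads via grouped-query attention, head dimension 128) and 16-bit KV activations; the exact values in the plot may differ slightly depending on rounding conventions.

```python
def kv_cache_gb(
    num_dense_layers: int,
    num_sparse_layers: int,
    seq_len: int = 16384,
    window: int = 8192,
    num_sinks: int = 512,
    num_kv_heads: int = 8,    # Llama-3.1-8B uses grouped-query attention with 8 KV heads
    head_dim: int = 128,
    bytes_per_elem: int = 2,  # 16-bit K/V activations
) -> float:
    """Rough KV cache size in GB right after prefill: dense layers cache K/V for every
    input token, sparse layers only for the sliding window plus the sink tokens."""
    per_token_bytes = 2 * num_kv_heads * head_dim * bytes_per_elem  # one K and one V vector
    dense_bytes = num_dense_layers * seq_len * per_token_bytes
    sparse_bytes = num_sparse_layers * min(seq_len, window + num_sinks) * per_token_bytes
    return (dense_bytes + sparse_bytes) / 1e9

# Fully dense 32-layer model vs. the configuration with 7 dense and 25 sparse layers
print(kv_cache_gb(32, 0), kv_cache_gb(7, 25))
```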
Auxiliary memory tokens.
Inspired by recent research on augmenting transformers with memory mechanisms [Mu et al., Zhang et al.], as well as by advances in inference-time compute scaling (and previous research on adding auxiliary “thinking” tokens [Hao et al., Goyal et al.]), we have investigated whether adding auxiliary tokens at training and inference time would help the model better aggregate information from the local context and thus improve long-context performance. We have considered a very simple form of memory tokens: a given number of these auxiliary tokens is periodically inserted into the input sequence, interleaving memory and regular input tokens. We have added only 2 extra learnable token embeddings (<memory> to signify “memory” and </memory> to signify “end of memory”) without any further modifications to the model architecture or the attention masking pattern. Thus, after modifying the input sequence, every M tokens of the original input are followed by a sequence of K memory tokens of the form
<memory> <memory> … <memory> </memory>
where the <memory> and </memory> embeddings are initialized to the mean of all token embeddings of the model and are learned. We keep the memory tokens in the input during training and optionally add them at inference time. Below we show ablations for the optimal values of M and K for the fully dense model (using the same training pipeline as for the dense/sparse models but without introducing sparse attention at the 32K-length steps):
As can be seen from both the dense and sparse model ablations, the values yielding the best performance on long-context tasks are 256 regular tokens (i.e., the model aggregates information from 256 tokens of the original input), followed by 8 memory tokens. We have used these values in the rest of our runs with memory tokens.
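As a concrete illustration, the sketch below interleaves memory tokens with a tokenized input at the chosen M = 256, K = 8 setting. The two token ids are assumed to be the ids of the added <memory> and </memory> embeddings, and we read “K memory tokens” as K tokens in total, the last of which is </memory>.

```python
def insert_memory_tokens(
    input_ids: list[int],
    memory_id: int,       # id of the learnable <memory> token (assumed to be added to the vocab)
    end_memory_id: int,   # id of the learnable </memory> token
    m: int = 256,         # regular tokens between memory blocks
    k: int = 8,           # memory tokens per block, the last one being </memory>
) -> list[int]:
    """Interleave memory tokens with the input sequence: after every m regular
    tokens, append (k - 1) <memory> tokens followed by a single </memory> token."""
    memory_block = [memory_id] * (k - 1) + [end_memory_id]
    out: list[int] = []
    for i in range(0, len(input_ids), m):
        out.extend(input_ids[i:i + m])
        out.extend(memory_block)
    return out
```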
Synthetic long context data and position ID shifting.
Two extra ingredients of our training recipe are (a) long synthetic instruction-following data produced using RAFT-like augmentations [Zhang et al.] and (b) shifting position IDs in RoPE to simulate longer sequence lengths [Zhu et al.]. We have discussed both of these methods in depth in our previous blogpost on context window length extension. We have used long synthetic datasets described there to train our sparse-attention models (also see the training recipe section below). We have also used PoSE-style position shifting (as described in our previous blogpost) which we apply only during the short SFT stage (with actual sequence length of 8K and simulated target length of 32K), so that the model learns to better handle larger relative distances during the instruction tuning phase with dense attention.
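For concreteness, here is a simplified two-chunk sketch of PoSE-style position ID shifting; the full recipe in [Zhu et al.] (and in our previous blogpost) also randomizes the chunk split, whereas this sketch fixes it at the midpoint.

```python
import random

def pose_position_ids(actual_len: int = 8192, target_len: int = 32768) -> list[int]:
    """Simplified two-chunk PoSE-style position IDs: the second half of the sequence
    is shifted by a random offset so the model sees relative distances up to
    target_len while only actual_len tokens are actually processed."""
    half = actual_len // 2
    skip = random.randint(0, target_len - actual_len)  # offset applied to the second chunk
    return list(range(half)) + list(range(half + skip, actual_len + skip))
```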
Training recipe.
Our training recipe consists of continued pre-training (CPT) followed by supervised fine-tuning (SFT) on instruction-following data, divided into 4 stages: (i) a short CPT stage (8K context length) with high-quality data, used as a “warmup” stage before CPT is done with sparse attention on samples longer than the sliding window size, (ii) a long CPT stage (32K context length) to adapt the model to the introduced sparse attention pattern, (iii) a short SFT stage (8K context length) with high-quality data and dense attention to restore the instruction-following capabilities of the model, and (iv) a long SFT stage (32K context length) to allow the model to tackle instruction following with sparse attention. Below is a detailed description of the data mixes and hyper-parameters used to produce our final models.
Stage1. CPT@8K
“Warmup” stage before doing CPT on 32K inputs with sparse attention. This stage also serves to effectively learn the embedding parameters for memory tokens, whenever they are added.
Data mix:
- 38% FineWeb (multilingual)
- 17.6% Cosmopedia
- 11.8% OpenWebMath
- 11.8% Knowledge Pile
- 5.8% Deepmind-math
- 8.7% Proof-Pile-2 (2.9% algebraic-stack, 5.8% arxiv)
- 5.8% UltraTextbooks
Training hyper-parameters:
- ~6.6B tokens total, batch size 2.6M tokens,
- linear learning rate warmup to 5e-6 over the first 8% of steps, then cosine decay to 1e-8 (see the schedule sketch after this list),
- zero weight decay.
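The warmup-plus-cosine schedule shape above is shared by all four stages (only the peak learning rate and warmup fraction change); a minimal sketch:

```python
import math

def lr_at_step(step: int, total_steps: int, peak_lr: float = 5e-6,
               warmup_frac: float = 0.08, min_lr: float = 1e-8) -> float:
    """Linear warmup to peak_lr over the first warmup_frac of steps,
    followed by cosine decay down to min_lr."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```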
Stage2. CPT@32K
CPT stage with sparse attention (sliding window attention, window size 8192, with 512 attention sinks).
Data:
- 100% LongDataCollections
Training hyper-parameters:
- ~7.4B tokens total, batch size 2.95M tokens,
- linear learning rate warmup to 5e-6 over the first 8% of steps, then cosine decay to 1e-8,
- zero weight decay.
Stage3. SFT@8K
Instruction tuning with dense attention on high-quality SFT data. We also apply position shifting (with target sequence length of 32K) during this stage in some of our experiments to better adapt the dense layers in the model for long inputs.
Data mix:
- 70% Magpie-Pro-300K-Filtered
- 15% OpenMathInstruct-2
- 15% SystemChat
Training hyper-parameters:
- ~1B tokens total, batch size 1M tokens,
- linear learning rate warmup to 5e-6 over the first 7.5% of steps, then cosine decay to 1e-8,
- zero weight decay.
Stage4. SFT@32K
Long instruction following fine-tuning with sparse attention.
Data mix:
- 77.5% Cerebras synthetic data:
  - 50% RAFT-augmented ConvQA,
  - 7.5% RAFT-augmented RAG-TGE,
  - 7.5% RAFT-augmented RAG-TGE in Chinese,
  - 7.5% RAFT-augmented ConvQA with syntactic questions,
  - 5.0% Distractor-augmented NarrativeQA
- 22.5% Open-source long instruction-tuning datasets:
  - 7.5% LongDataCollections (SFT subset – BookSum, Multi-passage QA),
  - 7.5% LongWriter6k,
  - 7.5% LongAlpaca12k
Training hyper-parameters:
- ~1.76B tokens total, batch size 1.96M,
- linear learning rate warmup to 1e-6 over the first 6.6% of steps, then cosine decay to 1e-8,
- zero weight decay.
Together, these four training stages constitute a training budget of roughly 17B tokens. Our decision not to use weight decay is based on existing recipes [Peng et al.], and we have also found that using a lower peak learning rate during the last long SFT stage yields better long-context performance overall.
Results.
Effect of fine-tuning on models with 1:N dense/sparse pattern.
First, we have trained and evaluated models with a regular spacing pattern of 1 dense attention layer followed by N sparse layers:
As can be seen, our training pipeline brought about performance gains across the different configurations, with the 1:1 model performing close to the dense baseline. Notably, the 1:7 config, which outperforms the 1:3 one in the zero-shot setting, also benefits from higher gains after training. This is, however, not reflected in LongBench performance:
Below is the breakdown of performance scores on HELMET by task category:
Effect of fine-tuning on models with selected dense layers.
Next, we compared the performance of models with a regular spacing of dense and sparse attention layers vs. the models where attention type was selected based on the layer sensitivity heuristic:
The models with dense layers selected based on sensitivity clearly outperformed the regularly spaced models. Furthermore, by applying position shifting during the short SFT stage, we were able to obtain better-performing models with the same KV memory budget, particularly for the model with 7 dense layers.
The models with dense layers selected from sensitivity analysis outperform the regularly spaced configs on the LongBench benchmark as well:
Notably, the model with 7 dense layers outperforms the dense baseline both on the full LongBench suite as well as the English-only subset of tasks (and the model with 4 dense layers also outperforms the dense baseline on the English-only subset):
Effect of memory tokens.
We have further studied how model performance is affected when memory tokens are interleaved with input sequence tokens using the optimal config we have found (8 memory tokens per 256 regular tokens). Interestingly, we have found that inserting the memory tokens in the input sequence only at training time still brings performance gains at inference, even though it has no impact on the inference-time input sequences and hence on the size of the KV cache. We hypothesize that this effect could arise from the model using tokens with low semantic content (such as punctuation tokens [Razzhigaev et al.]) as memory tokens, aggregating information from the rest of the sequence in the KV cache of these tokens. Performance is further improved when memory tokens are also added at inference time, at the expense of a slight increase in KV cache memory usage (8 extra tokens per 256 regular tokens, i.e., a ~3% larger KV cache in the dense layers).
Below are results for the model with 7 dense layers selected from the sensitivity analysis:
We have observed improved performance when memory tokens are added during training in both model configs, and we see the best results when memory tokens and position ID shifting during the short SFT stage are combined. Similar results are observed for LongBench:
We name our best models (trained with position shifting and memory tokens) with 4 and 7 dense layers as Llama-3-CBHybridM-8B (medium) and Llama-3-CBHybridL-8B (large), respectively.
Conclusion.
We have demonstrated that, using a combination of techniques, it is possible to fine-tune a pre-trained dense-decoder LLM with sparse attention so as to cut its KV cache memory footprint by nearly half while almost matching, and in some cases outperforming, the original dense model on long-context benchmarks. The specific techniques used to achieve these results were preserving dense attention in the most sensitive layers of the pre-trained model, introducing auxiliary memory tokens, position ID shifting in the RoPE mechanism, and training on synthetic long instruction-following data. We have demonstrated the effectiveness of our approach using Llama-3.1-8B-Instruct. Further work will consider more intricate layer-wise sparsity patterns [Jiang et al.], as well as combining attention sparsity with quantization and inter-layer KV sharing.
The models were built with Cerebras Model Zoo 2.4 and we are open-sourcing models on the HuggingFace hub.
@misc{cerebras2025cb-hybrid-llama,
author = {Lazarevich, Ivan and Hassanpour, Mohammad and Venkatesh, Ganesh},
title = {Compressing KV cache memory by half with sparse attention},
month = {March},
year = {2025},
howpublished = {\url{https://www.cerebras.ai/blog/compressing-kv-cache-memory-by-half-with-sparse-attention}}
}
References.
Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245.
Gu, A., & Dao, T. (2023). Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752.
Dong, X., Fu, Y., Diao, S., Byeon, W., Chen, Z., Mahabaleshwarkar, A. S., ... & Molchanov, P. (2024). Hymba: A hybrid-head architecture for small language models. arXiv preprint arXiv:2411.13676.
Liu, A., Feng, B., Wang, B., Wang, B., Liu, B., Zhao, C., ... & Xu, Z. (2024). Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model. arXiv preprint arXiv:2405.04434.
Jiang, H., Li, Y., Zhang, C., Wu, Q., Luo, X., Ahn, S., ... & Qiu, L. (2024). Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention. arXiv preprint arXiv:2407.02490.
Yang, A., Yu, B., Li, C., Liu, D., Huang, F., Huang, H., ... & Zhang, Z. (2025). Qwen2.5-1M Technical Report. arXiv preprint arXiv:2501.15383.
Hooper, C., Kim, S., Mohammadzadeh, H., Mahoney, M. W., Shao, Y. S., Keutzer, K., & Gholami, A. (2024). Kvquant: Towards 10 million context length llm inference with kv cache quantization. arXiv preprint arXiv:2401.18079.
Rajput, S., Sheng, Y., Owen, S., & Chiley, V. (2024). Inference-friendly models with mixattention. arXiv preprint arXiv:2409.15012.
Bai, Y., Lv, X., Zhang, J., Lyu, H., Tang, J., Huang, Z., ... & Li, J. (2023). Longbench: A bilingual, multitask benchmark for long context understanding. arXiv preprint arXiv:2308.14508.
Yen, H., Gao, T., Hou, M., Ding, K., Fleischer, D., Izsak, P., ... & Chen, D. (2024). Helmet: How to evaluate long-context language models effectively and thoroughly. arXiv preprint arXiv:2410.02694.
Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453.
Han, C., Wang, Q., Xiong, W., Chen, Y., Ji, H., & Wang, S. (2023). Lm-infinite: Simple on-the-fly length generalization for large language models. arXiv preprint arXiv:2308.16137.
Zhang, T., Patil, S. G., Jain, N., Shen, S., Zaharia, M., Stoica, I., & Gonzalez, J. E. (2024). Raft: Adapting language model to domain specific rag. arXiv preprint arXiv:2403.10131.
Zhu, D., Yang, N., Wang, L., Song, Y., Wu, W., Wei, F., & Li, S. (2023). Pose: Efficient context window extension of llms via positional skip-wise training. arXiv preprint arXiv:2309.10400.
Hao, S., Sukhbaatar, S., Su, D., Li, X., Hu, Z., Weston, J., & Tian, Y. (2024). Training large language models to reason in a continuous latent space. arXiv preprint arXiv:2412.06769.
Mu, J., Li, X., & Goodman, N. (2024). Learning to compress prompts with gist tokens. Advances in Neural Information Processing Systems, 36.
Goyal, S., Ji, Z., Rawat, A. S., Menon, A. K., Kumar, S., & Nagarajan, V. (2023). Think before you speak: Training language models with pause tokens. arXiv preprint arXiv:2310.02226.
Zhang, P., Liu, Z., Xiao, S., Shao, N., Ye, Q., & Dou, Z. (2024). Soaring from 4k to 400k: Extending llm’s context with activation beacon. arXiv preprint arXiv:2401.03462.
Peng, B., Quesnelle, J., Fan, H., & Shippole, E. (2023). Yarn: Efficient context window extension of large language models. arXiv preprint arXiv:2309.00071.
Wu, W., Wang, Y., Xiao, G., Peng, H., & Fu, Y. (2024). Retrieval head mechanistically explains long-context factuality. arXiv preprint arXiv:2404.15574.
Razzhigaev, A., Mikhalchuk, M., Rahmatullaev, T., Goncharova, E., Druzhinina, P., Oseledets, I., & Kuznetsov, A. (2025). LLM-Microscope: Uncovering the Hidden Role of Punctuation in Context Memory of Transformers. arXiv preprint arXiv:2502.15007.