
Token Compression: Reducing Attention Waste?

Using an LLM to compress multiple tokens into one token

The AI Dude
Aug 11, 2023
8 min read

Research Question

Can we compress multiple tokens into one token and decode them back, creating more efficient transformers?

LLAMA-7b has a hidden size of 4096. Its vocabulary size is 32k. It needs a vector of size 4096 to represent a token like "is" or even a half word. For LLAMA-2-70b, the hidden size is 8192.

I have always wondered whether this is too large a vector to represent just one token. By comparison, models such as sentence transformers, the CLIP text encoder, and the OpenAI embedding API compress a large amount of text into a single vector. For perspective, the OpenAI embedding API can represent a text of up to 8k tokens with an embedding of size just 1536.

To be fair, these text embedders have a different objective than a language model: embedding models are optimized for search and similarity, while language models are optimized for next-token prediction.

LLAMA-7b

Hidden Size 4096
Vocab Size 32k
Purpose Next Token Prediction

OpenAI Embeddings

Embedding Size 1536
Context Length 8k tokens
Purpose Search & Similarity

So, I wondered if I could combine this somehow. Can we compress multiple tokens into one token and decode those tokens from that one token? Then we can train a new transformer that operates on these 'group tokens' instead of individual ones.

There are many potential use cases for such a system, such as online knowledge injection, vocabulary extension, and next-phrase prediction. I leave these explorations to future work.

Model Architecture

[Figure] Model architecture with a frozen LLM, LoRA adapters, and special embedding tokens

Frozen LLM Base

Pythia 1.4b with a hidden size of 2048; weights are frozen for parameter-efficient training

LoRA Adaptation

Low-rank adaptation layers that are trainable while keeping base model frozen

Special Tokens

<Embed> and <Decode> tokens for compression and reconstruction stages
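To make the architecture concrete, here is a minimal setup sketch using Hugging Face transformers and peft. The base model matches the post (Pythia 1.4b); the LoRA hyperparameters and the exact special-token strings are illustrative assumptions, not values taken from the post.

```python
# Minimal setup sketch: frozen Pythia 1.4b + trainable LoRA adapters + special tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

BASE_MODEL = "EleutherAI/pythia-1.4b"  # frozen base LLM, hidden size 2048

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

# Add the two special tokens used for compression and reconstruction,
# and grow the embedding table to match.
tokenizer.add_special_tokens({"additional_special_tokens": ["<Embed>", "<Decode>"]})
model.resize_token_embeddings(len(tokenizer))

# Wrap the frozen base model with trainable low-rank adapters.
# (In a real run the new special-token embedding rows would also need to be
# trainable; that detail is omitted here for brevity.)
lora_config = LoraConfig(
    r=16,                                # assumed rank
    lora_alpha=32,                       # assumed scaling
    target_modules=["query_key_value"],  # GPT-NeoX / Pythia attention projection
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()       # only the adapter weights are trainable
```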

Two-Stage Process

1. Encoding Stage: text tokens + <Embed> token → compressed embedding (the "context token"). Inspired by the AutoCompress paper.

2. Decoding Stage: context token + <Decode> token → reconstruction of the original tokens via next-token prediction.
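Continuing from the setup sketch above, here is one plausible shape for the two stages. It assumes the "context token" is simply the final-layer hidden state at the <Embed> position, and that reconstruction is trained as ordinary next-token prediction on a sequence prefixed by that vector and the <Decode> embedding; the function names and details are mine, not the post's.

```python
import torch
import torch.nn.functional as F

def encode(model, tokenizer, text, device="cpu"):
    """Stage 1: text tokens + <Embed> -> one compressed 'context token' vector."""
    embed_id = tokenizer.convert_tokens_to_ids("<Embed>")
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    ids = torch.cat([ids, torch.tensor([[embed_id]], device=device)], dim=1)
    out = model(input_ids=ids, output_hidden_states=True)
    # Final-layer hidden state at the <Embed> position: shape (1, hidden_size).
    return out.hidden_states[-1][:, -1, :]

def decode_loss(model, tokenizer, context_vec, target_text, device="cpu"):
    """Stage 2: [context token, <Decode>] -> next-token prediction of the original text."""
    decode_id = tokenizer.convert_tokens_to_ids("<Decode>")
    target_ids = tokenizer(target_text, return_tensors="pt").input_ids.to(device)

    emb = model.get_input_embeddings()
    decode_emb = emb(torch.tensor([[decode_id]], device=device))
    target_emb = emb(target_ids)
    # Sequence: [context vector, <Decode>, t_0, ..., t_{n-1}]
    inputs_embeds = torch.cat([context_vec[:, None, :], decode_emb, target_emb], dim=1)

    logits = model(inputs_embeds=inputs_embeds).logits
    # Positions 1..n predict t_0..t_{n-1}: standard shifted cross-entropy.
    shift_logits = logits[:, 1:-1, :]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        target_ids.reshape(-1),
    )
```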

Experiment 1: 1-16 Token Compression

Dataset

The Pile dataset, chunks of 1-16 tokens + EoS

Batch Size

64, with parameter-efficient fine-tuning (only the LoRA adapters are trained)

Iterations

~250k iterations until convergence
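A rough sketch of how the training chunks could be produced with the tokenizer from the setup above. The specific Pile mirror and the uniform sampling of chunk lengths are assumptions on my part; the post only states chunks of 1-16 tokens plus an EoS.

```python
import random
from datasets import load_dataset

def chunk_stream(tokenizer, min_len=1, max_len=16):
    """Yield token chunks of random length in [min_len, max_len], each ending in EoS."""
    # "monology/pile-uncopyrighted" is one public mirror of The Pile; the post
    # just says "The Pile", so treat this dataset name as an illustrative choice.
    ds = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)
    for example in ds:
        ids = tokenizer(example["text"]).input_ids
        i = 0
        while i < len(ids):
            n = random.randint(min_len, max_len)           # 1-16 tokens per chunk
            yield ids[i:i + n] + [tokenizer.eos_token_id]  # chunk + EoS
            i += n
```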

[Figure] Loss curve showing excellent convergence to ~0.05 (almost perfect reconstruction)

Final Loss: ~0.05
Avg Tokens Compressed: ~8
Hidden Size: 2048

[Figure] Sample outputs at temperature 0.1 showing near-perfect reconstruction
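As a usage sketch, reconstructions like the sampled outputs above can be generated by decoding from the context token at low temperature (the post's samples use temperature 0.1). This builds on the hypothetical encode helper from the earlier sketch and is not the post's exact decoding code.

```python
import torch

@torch.no_grad()
def reconstruct(model, tokenizer, text, max_new_tokens=20, temperature=0.1, device="cpu"):
    """Compress `text` into a single context vector, then sample it back out."""
    ctx = encode(model, tokenizer, text, device)          # (1, hidden_size)
    decode_id = tokenizer.convert_tokens_to_ids("<Decode>")
    emb = model.get_input_embeddings()
    seq = torch.cat([ctx[:, None, :], emb(torch.tensor([[decode_id]], device=device))], dim=1)

    out_ids = []
    for _ in range(max_new_tokens):
        logits = model(inputs_embeds=seq).logits[:, -1, :] / temperature
        next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        if next_id.item() == tokenizer.eos_token_id:      # chunks were trained to end in EoS
            break
        out_ids.append(next_id.item())
        seq = torch.cat([seq, emb(next_id)], dim=1)       # append the sampled token's embedding
    return tokenizer.decode(out_ids)

# e.g. reconstruct(model, tokenizer, "The quick brown fox jumps over the lazy dog")
```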

Key Finding

A hidden size of ~2048 can easily compress an average of 8 tokens into a single vector with near-perfect reconstruction capability.

Experiment 2: Pushing the Limits (1-32 Tokens)

The next step was to investigate how far I could push this. I further fine-tuned the same model on token ranges of 1-32 while keeping everything else the same.

Extended Range

1-32 tokens (double the previous range)

Batch Size

32 (reduced due to memory constraints)

Same Iterations

Same number of iterations as Experiment 1

[Figure] Loss curve showing convergence to ~0.2 (4x higher than the previous model)

Experiment 1: final loss ~0.05 (excellent)
Experiment 2: final loss ~0.2 (moderate)

[Figure] Sample outputs showing the model struggling with longer contexts

Insight

The model struggles to compress longer contexts. The loss curve suggests that it might still converge with continued training, but compute limitations prevented further investigation.

Summary

This project is a continuation of my previous post on the quest for long context. I tried working on several ideas, such as the AutoCompress approach mentioned above, but given the pace of research I realized it was only a matter of time before someone trained open-source long-context LLMs. For example, this effort (LLAMA-2-32k) by Together Computer is a really good one. So I decided to explore in a different direction.

Key Contributions

Efficiency Investigation

Demonstrated that transformer hidden sizes may be overkill for single token representation

Simple Architecture

Presented a simple method to convert pre-trained LLMs into token compression models

Compression Results

Showed that a 2048-dimensional embedding can compress ~8 tokens near-losslessly and ~16 tokens with acceptable loss

Future Ideas

A token compression model may not have immediate benefits out of the box, so I want to share some ideas for where and how it could be used. I will also be exploring some of these ideas myself.

Transformer Insights

Understanding how transformers compress and decompress concepts, leading to more efficient architectures like growing transformers with increasing embedding dimensions.

Growing Vocabulary

Training transformers with expanding vocabularies for better context length and output quality, enabling domain adaptation and multi-language/domain architectures.

Decode Time Reduction

Massive reduction in inference time: large transformer predicts groups of tokens, smaller transformer decodes the groups.

Context Length Extension

Increasing context length by ~8x through token compression, enabling longer document processing capabilities.

Key Challenge

The main challenge is how to optimally group tokens. A naive approach of grouping adjacent tokens of size 8 may fail when complex information needs compression.

I hope you have enjoyed this essay. The complete source code with model weights is available as a GitHub gist. While it is a bit messy, I plan to share a cleaner repo later.

Code Available

Full implementation and model weights available on GitHub


Topics

Token Compression · LLM Efficiency · Transformers · LoRA · Pythia · AutoCompress · Context Length