---
library_name: transformers
tags: []
pipeline_tag: text-generation
license: mit
---

# Transformer (Pro) Model Checkpoint for the Forgetting Transformer Paper

This is the final checkpoint of the 760M-parameter Transformer (Pro) model from the main experiment of the ICLR 2025 paper [Forgetting Transformer: Softmax Attention with a Forget Gate](https://arxiv.org/abs/2503.02130).

## Model Details

### Model Description

- **Developed by:** Zhixuan Lin
- **Model type:** Transformer (Pro)
- **Language(s) (NLP):** English
- **License:** MIT

### Model Sources

- **Repository:** https://github.com/zhixuan-lin/forgetting-transformer
- **Paper:** https://arxiv.org/abs/2503.02130

## Uses

### Direct Use

First, install the `forgetting-transformer` repository as a Python package, along with the required dependencies. We pin the versions below because this combination is known to work, but pinning is optional:

```bash
# We recommend recording the commit hash you install; we may introduce breaking changes in the future.
# Uninstall any existing copy first to avoid conflicts
pip uninstall forgetting_transformer && pip install -U git+https://github.com/zhixuan-lin/forgetting-transformer
pip install pytest einops numpy
pip install torch==2.4.0
pip install transformers==4.44.0
# Other flash-linear-attention commits are not guaranteed to work; we may fix this later
pip install --no-deps --force-reinstall git+https://github.com/sustcsonglin/flash-linear-attention.git@1c5937eeeb8b0aa17bed5ee6dae345b353196bd4
```

Usage example:

```python
import forgetting_transformer.model.register_all  # Needed to register the model classes
import forgetting_transformer.tokenizer  # Needed to register the tokenizer class
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained("zhixuan-lin/transformer-pro-760m-longcrawl64-48b")
tokenizer = AutoTokenizer.from_pretrained(
    "zhixuan-lin/transformer-pro-760m-longcrawl64-48b",
    add_bos_token=True,
    clean_up_tokenization_spaces=False,
)

# Generation with the Hugging Face API
prompt = "The best thing to do in San Francisco is"
model = model.cuda()
encoded = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    output = model.generate(
        encoded,
        max_new_tokens=30,
    )[0]
pred = tokenizer.decode(output, skip_special_tokens=True)
print(pred)

# You can also compute the logits or the loss directly from suitable inputs
batch_size, seq_len = encoded.shape
labels = encoded
# Shift the labels right by one position and start the sequence with BOS
input_ids = torch.roll(labels, shifts=1, dims=-1)
input_ids[:, 0] = tokenizer.bos_token_id  # 50256
out = model(input_ids=input_ids, labels=labels)
assert out.loss.size() == (batch_size, seq_len)
# Logits are not returned (to save memory) when labels are given
assert out.logits is None
# To get the logits, call the model without labels
out = model(input_ids=input_ids)
assert out.logits.size() == (batch_size, seq_len, tokenizer.vocab_size)
```

## Limitations

This is a small model trained on a small number of tokens from LongCrawl64, provided for reproducibility and research purposes. LongCrawl64 is a long-context dataset intended for research and is not designed for optimal downstream task performance; it also uses an unusual tokenization process (see [`tokenizer.py`](https://github.com/zhixuan-lin/forgetting-transformer/blob/main/src/forgetting_transformer/tokenizer.py)). This model is therefore suitable only for research purposes (e.g., inspecting attention maps).

If you want to compare this model with models trained in a different setting or on a different dataset, **you should train it from scratch on your own dataset and under your own setting for the comparison.**

## Training Details

### Training Data

This model was trained on roughly 48B tokens from LongCrawl64 with a training context length of 16k tokens.

### Training Procedure

Please see [our paper](https://arxiv.org/abs/2503.02130) for details. The training code is also available in our [official repository](https://github.com/zhixuan-lin/forgetting-transformer).
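
The loss example in the usage code builds `input_ids` by shifting `labels` one position to the right with `torch.roll` and then overwriting position 0 with the BOS token, so that `input_ids[t] == labels[t - 1]` for every `t >= 1`. The minimal sketch below mirrors that alignment on plain Python lists, so it can be checked without a GPU, and reduces per-token losses to a mean loss and a perplexity. It assumes that `out.loss` holds per-token cross-entropy values in nats (the card only states its shape), and the helper names `shift_right` and `mean_loss_and_perplexity` are hypothetical; they are not part of the `forgetting_transformer` package.

```python
import math
from typing import List, Sequence, Tuple

BOS_TOKEN_ID = 50256  # value noted in the usage example's comment


def shift_right(labels: Sequence[int], bos_token_id: int = BOS_TOKEN_ID) -> List[int]:
    """Hypothetical helper mirroring torch.roll(labels, shifts=1, dims=-1)
    followed by input_ids[0] = bos_token_id, for a single sequence.

    torch.roll wraps the last label around to position 0; overwriting that
    slot with BOS leaves input_ids[t] == labels[t - 1] for every t >= 1.
    """
    if not labels:
        return []
    rolled = [labels[-1]] + list(labels[:-1])  # circular shift right by one
    rolled[0] = bos_token_id  # drop the wrapped-around last label
    return rolled


def mean_loss_and_perplexity(per_token_losses: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: average per-token NLL (assumed to be in nats)
    and the corresponding perplexity exp(mean NLL)."""
    if not per_token_losses:
        raise ValueError("need at least one per-token loss")
    mean_nll = sum(per_token_losses) / len(per_token_losses)
    return mean_nll, math.exp(mean_nll)


if __name__ == "__main__":
    labels = [11, 22, 33, 44]  # hypothetical token ids
    print(shift_right(labels))
    print(mean_loss_and_perplexity([2.0, 2.0, 2.0]))
```

With the real model, the corresponding reduction would be `out.loss.float().mean()` followed by `torch.exp`, under the same assumption about what `out.loss` contains.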
**BibTeX:**

```
@inproceedings{lin2025forgetting,
  title={Forgetting Transformer: Softmax Attention with a Forget Gate},
  author={Zhixuan Lin and Evgenii Nikishin and Xu He and Aaron Courville},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=q2Lnyegkr8}
}
```