
Block-Recurrent Transformer



We introduce the Block-Recurrent Transformer, which applies a transformer layer in a recurrent fashion along a 
sequence, and has linear complexity with respect to sequence length. Our recurrent cell operates on blocks of tokens 
rather than single tokens, and leverages parallel computation within a block in order to make efficient use of accelerator 
hardware. The cell itself is strikingly simple. It is merely a transformer layer: it uses self-attention and cross-attention to 
efficiently compute a recurrent function over a large set of state vectors and tokens. Our design was inspired in part by 
LSTM cells, and it uses LSTM-style gates, but it scales the typical LSTM cell up by several orders of magnitude. Our 
implementation of recurrence has the same cost in both computation time and parameter count as a conventional 
transformer layer, but offers dramatically improved perplexity in language modeling tasks over very long sequences.
Our model outperforms a long-range Transformer-XL baseline by a wide margin, while running twice as fast. We 
demonstrate its effectiveness on PG19 (books), arXiv papers, and GitHub source code.
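
The LSTM-style gating mentioned above can be illustrated with a minimal sketch of a fixed gate. It assumes the gate is a learned per-dimension bias passed through a sigmoid, which then interpolates between the previous state and a candidate update. This is an illustration under that assumption, not the code in this repository, and the names `fixed_gate_update`, `state`, `candidate`, and `gate_bias` are hypothetical. It is written in plain Python on lists of floats so it runs without dependencies; in a real model the same arithmetic would be applied elementwise to tensors.

```python
import math


def sigmoid(x):
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def fixed_gate_update(state, candidate, gate_bias):
    """Blend the previous state with a candidate update, one dimension at a time.

    Assumed form (hypothetical, for illustration only):
        g = sigmoid(gate_bias)
        next_state = state * g + candidate * (1 - g)
    The gate depends only on the bias, not on the input, hence "fixed".
    """
    next_state = []
    for s, z, b in zip(state, candidate, gate_bias):
        g = sigmoid(b)                       # how much of the old state to keep
        next_state.append(s * g + z * (1.0 - g))
    return next_state


# A bias of 0 gives g = 0.5, so the result is the midpoint of state and candidate.
blended = fixed_gate_update([2.0, -1.0], [4.0, 1.0], [0.0, 0.0])

# A large positive bias keeps the state almost unchanged.
mostly_kept = fixed_gate_update([2.0], [4.0], [10.0])
```

A gate of this shape lets each state dimension settle on its own retention rate, which is one plausible reading of the "fixed simple gating" option in the usage example.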



$ pip install block-recurrent-transformer-pytorch


import torch
from block_recurrent_transformer_pytorch import BlockRecurrentTransformer

model = BlockRecurrentTransformer(
    num_tokens = 20000,             # vocab size
    dim = 512,                      # model dimensions
    depth = 6,                      # depth
    dim_head = 64,                  # attention head dimensions
    heads = 8,                      # number of attention heads
    max_seq_len = 1024,             # the total receptive field of the transformer, in the paper this was 2 * block size
    block_width = 512,              # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forwards becomes the new cached xl memories, which is a block size of 1 (please open an issue if i am wrong)
    num_state_vectors = 512,        # number of state vectors, i believe this was a single block size in the paper, but can be any amount
    recurrent_layers = (4,),        # where to place the recurrent layer(s) for states with fixed simple gating
    use_compressed_mem = False,     # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
    compressed_mem_factor = 4,      # compression factor of compressed memories
    use_flash_attn = True           # use flash attention, if on pytorch 2.0
)

seq = torch.randint(0, 2000, (1, 1024))

out, mems1, states1 = model(seq)
out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
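
The three chained calls above generalize to a loop over any number of consecutive segments. The sketch below shows one way to write that loop. It relies only on the call pattern shown above (`model(seq, xl_memories = ..., states = ...)` returning `out, mems, states`); the helper name `run_over_segments` is hypothetical and not part of the library. It assumes every segment has the same length as in the example above (`max_seq_len`), since that is the only case the example demonstrates.

```python
def run_over_segments(model, segments):
    """Feed consecutive segments through `model`, carrying memories and states.

    `model` is any callable following the pattern from the usage example:
        out, mems, states = model(segment)                       # first call
        out, mems, states = model(segment, xl_memories=..., states=...)
    `segments` is an iterable of inputs, in order.
    Returns the list of per-segment outputs plus the final memories and states.
    """
    carry = {}            # empty on the first call, so the model uses its defaults
    outputs = []
    mems, states = None, None

    for segment in segments:
        out, mems, states = model(segment, **carry)
        outputs.append(out)
        # Hand the freshly returned memories and states to the next segment.
        carry = dict(xl_memories = mems, states = states)

    return outputs, mems, states


# Hypothetical usage with the model defined above (requires torch and the library):
#   long_seq = torch.randint(0, 2000, (1, 4096))
#   outs, mems, states = run_over_segments(model, long_seq.split(1024, dim = -1))
```

Passing the carried values through a keyword dictionary keeps the first call identical to `model(seq)` in the example, so the helper does not need to guess what the library uses as its "no memory" default.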
