```
import os
import trax
from trax import layers as tl # core building block
import jax
from trax import fastmath # uses jax, offers numpy on steroids
# fastmath.use_backend('tensorflow-numpy')
import functools
from trax.fastmath import numpy as np # note, using fastmath subset of numpy!
from trax.layers import (
#tie_in,
length_normalized,
apply_broadcasted_dropout,
look_adjacent,
permute_via_gather,
permute_via_sort,
)
from jax.lax import tie_in
```

## 1 Introduction

Two ‘reforms’ make the Transformer more memory- and compute-efficient: *Reversible Layers* reduce memory, and *Locality Sensitive Hashing (LSH)* reduces the cost of dot-product attention for large input sizes. In this article we will look more closely at LSH and how it is used in the Reformer model.

Specifically, we will:

- review dot-product self attention for reference
- examine LSH based self attention
- extend our understanding and familiarity with Trax infrastructure

## 2 Trax Efficient Attention classes

Trax is similar to other popular NN development platforms such as Keras (now integrated into Tensorflow) and Pytorch in that it uses ‘layers’ as a useful level of abstraction. Layers are often represented as *classes*. We’re going to improve our understanding of Trax by locally extending the classes used in the attention layers. We will extend only the ‘forward’ functions and utilize the existing attention layers as parent classes. The original code can be found at github:trax/layers/Research/Efficient_attention. This link references release 1.3.9; note that it is under the ‘research’ directory, as this is an area of active research. When accessing the code on GitHub for review on this assignment, be sure you select the 1.3.9 release tag; the master copy may have new changes.

**Figure 1: Reference Tag 1.3.9 on github**

Let’s spend a few moments reviewing the classes we will be using.

**Figure 2: Classes from Trax/layers/Research/Efficient_Attention.py that we will be utilizing.**

Starting on the right in the diagram above you see `SelfAttention`, a ‘traditional’ implementation of the dot product attention. The parent of this class is `base.layer`, which has the routines used by all layers. `SelfAttention` has an important feature in the *Forward* routine: it supports a `use_reference_code` capability that selects implementations that limit some of the complexities to provide a more easily understood version of the algorithms. In particular, it implements a nested loop that treats each *‘example, head’* independently. This simplifies our work, as we need only worry about matrix operations on one *‘example, head’* at a time. This loop calls *forward_unbatched*, which is the child method that we will be overriding.
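The reference-code loop structure can be pictured with a tiny stand-in (a hypothetical plain-Python sketch; `forward_reference_sketch` and its arguments are illustrative names, not the actual Trax implementation):

```python
# Hypothetical sketch of the reference-code nested loop:
# each (example, head) pair is processed independently by forward_unbatched.

def forward_reference_sketch(x_batch, n_heads, forward_unbatched):
    """x_batch: list of examples; forward_unbatched: fn of one (example, head)."""
    outputs = []
    for example in x_batch:          # loop over examples
        head_outputs = []
        for head in range(n_heads):  # loop over heads
            head_outputs.append(forward_unbatched(example, head))
        outputs.append(head_outputs)
    return outputs

# usage: a stand-in forward_unbatched that just tags its inputs
out = forward_reference_sketch(["ex0", "ex1"], 3, lambda ex, h: (ex, h))
print(len(out), len(out[0]))  # 2 examples, 3 heads each
```

This is only the control flow; the real `forward_unbatched` does the per-head matrix operations shown later.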

We will be implementing the *forward_unbatched* version of `SelfAttention` to highlight the differences between this and the LSH implementation.

On the top left is `LSHSelfAttention`. This is the routine used in the Reformer architecture. We will override the *forward_unbatched* section of this and some of the utility functions it uses to explore its implementation in more detail.

The code we will be working with is from the Trax source, and as such has implementation details that will make it a bit harder to follow. However, it will allow use of the results along with the rest of the Trax infrastructure. I will try to briefly describe these as they arise. The Trax documentation can also be referenced.

### 2.1 Trax Details

The goal in this article is to override a few routines in the Trax classes with our own versions. To maintain their functionality in a full Trax environment, many of the details we might ignore in example versions of these routines will be maintained in this code. Here are some of the considerations that may impact our code:

- Trax operates with multiple back-end libraries; we will see special cases that utilize unique features.
- ‘Fancy’ numpy indexing is not supported in all backend environments and must be emulated in other ways.
- Some operations don’t have gradients for backprop and must be ignored or include forced re-evaluation.

Here are some of the functions we may see:

- Abstracted as `fastmath`, Trax supports multiple backends such as JAX and TensorFlow2.
- tie_in: Some non-numeric operations must be invoked during backpropagation. Normally, the gradient compute graph would determine invocation, but these functions are not included. To force re-evaluation, they are ‘tied’ to other numeric operations using tie_in.
- stop_gradient: Some operations are intentionally excluded from backprop gradient calculations by setting their gradients to zero.
- Below we will execute `from trax.fastmath import numpy as np`; this uses accelerated forms of numpy functions. This is, however, a *subset* of numpy.

## 3 Full Dot-Product Self Attention

### 3.1 Description

**Figure 3: Project datapath and primary data structures and where they are implemented**

The diagram above shows many of the familiar data structures and operations related to attention and describes the routines in which they are implemented. We will start by working on *our_simple_attend*, our simpler version of the original *attend* function. We will review the steps in performing dot-product attention, with more focus on the details of the operations and their significance. This is useful when comparing to LSH attention. Note we will be discussing a single example/head unless otherwise specified.

**Figure 4: dot-product of Query and Key**

The *attend* function receives *Query* and *Key*. As a reminder, they are produced by a matrix multiply of all the inputs with a single set of weights. We will describe the inputs as *embeddings* assuming an NLP application; however, this is not required. This matrix multiply works very much like a convolutional network, where a set of weights (a filter) slides across the input vectors, leaving behind a map of the similarity of the input to the filter. In this case, the filters are the weight matrices \(W^Q\) and \(W^K\), and the resulting maps are Q and K. Q and K have dimensions (n_seq, n_q) and (n_seq, n_k), where n_seq is the number of input embeddings and n_q or n_k is the selected size of the Q or K vectors. Note the shading of Q and K; this reflects the fact that each entry is associated with a particular input embedding.

You will note later in the code that K is optional. Apparently, similar results can be achieved using the Query alone, saving the compute and storage associated with K. In that case, the dot-product in *attend* is matmul(q, qᵀ). Note the resulting dot-product (*Dot*) entries describe a complete (n_seq, n_seq) map of the similarity of all entries of q vs all entries of k. This is reflected in the notation in the dot-product boxes of \(w_n\), \(w_m\), representing word_n, word_m. Note that each row of *Dot* describes the relationship of an input embedding, say \(w_0\), with every other input.

In some applications, some values are masked. This can be used, for example, to exclude results that occur later in time (causal) or to mask padding or other inputs.

**Figure 5: Masking**

The routine below *mask_self_attention* implements a flexible masking capability. The masking is controlled by the information in q_info and kv_info.

```
def mask_self_attention(
    dots, q_info, kv_info, causal=True, exclude_self=True, masked=False
):
    """Performs masking for self-attention."""
    if causal:
        mask = fastmath.lt(q_info, kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    if exclude_self:
        mask = np.equal(q_info, kv_info).astype(np.float32)
        dots = dots - 1e5 * mask
    if masked:
        zeros_like_kv_info = tie_in(kv_info, np.zeros_like(kv_info))
        mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    return dots
```

A SoftMax is applied per row of the *Dot* matrix to scale the values in the row between 0 and 1.

**Figure 6: SoftMax per row of Dot**

### 3.2 our_softmax

This code uses a separable form of the softmax calculation. Recall the softmax: \[ softmax(x_i)=\frac{\exp(x_i)}{\sum_j \exp(x_j)}\tag{1}\] This can be alternately implemented as: \[ logsumexp(x)=\log{({\sum_j \exp(x_j)})}\tag{2}\] \[ softmax(x_i)=\exp({x_i - logsumexp(x)})\tag{3}\] The work below will maintain a copy of the logsumexp allowing the softmax to be completed in sections. You will see how this is useful later in the LSHSelfAttention class. We’ll create a routine to implement that here with the addition of a passthrough. The matrix operations we will be working on below are easier to follow if we can maintain integer values. So, for tests, we will skip the softmax in some cases.
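To see why keeping the logsumexp lets the softmax be completed in sections, here is a small plain-Python check (using the `math` module rather than fastmath; `softmax_with_lse` is an illustrative name). It softmaxes two halves of a vector separately, then rescales each half by `exp(lse_half - lse_total)` to recover the full softmax:

```python
import math

def softmax_with_lse(xs):
    """Return (softmax(xs), logsumexp(xs)) per equations (2) and (3)."""
    lse = math.log(sum(math.exp(x) for x in xs))
    return [math.exp(x - lse) for x in xs], lse

xs = [1.0, 2.0, 3.0, 4.0]
full, _ = softmax_with_lse(xs)

# Softmax each half separately, keeping each half's logsumexp.
left, lse_l = softmax_with_lse(xs[:2])
right, lse_r = softmax_with_lse(xs[2:])

# Combine the sections: weight each half by exp(lse_half - lse_total).
lse_total = math.log(math.exp(lse_l) + math.exp(lse_r))
combined = [p * math.exp(lse_l - lse_total) for p in left] + \
           [p * math.exp(lse_r - lse_total) for p in right]

print(max(abs(c - f) for c, f in zip(combined, full)))  # ~0.0
```

The same rescaling trick is what lets LSH attention softmax each hash round independently and merge the rounds afterwards.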

```
def our_softmax(x, passthrough=False):
    """softmax with passthrough"""
    logsumexp = fastmath.logsumexp(x, axis=-1, keepdims=True)
    o = np.exp(x - logsumexp)
    if passthrough:
        return (x, np.zeros_like(logsumexp))
    else:
        return (o, logsumexp)
```

Let’s check our implementation.

```
## compare softmax(a) using both methods
a = np.array([1.0, 2.0, 3.0, 4.0])
sma = np.exp(a) / sum(np.exp(a))
print(sma)
sma2, a_logsumexp = our_softmax(a)
print(sma2)
print(a_logsumexp)
```

```
[0.0320586 0.08714432 0.2368828 0.6439142 ]
[0.0320586 0.08714431 0.23688279 0.64391416]
[4.44019]
```

The purpose of the dot-product is to ‘focus attention’ on some of the inputs. Dot now has entries appropriately scaled to enhance some values and reduce others. These are now applied to the \(V\) entries.

**Figure 7: Applying Attention to \(V\)**

\(V\) is of size (n_seq,n_v). Note the shading in the diagram. This is to draw attention to the operation of the matrix multiplication. This is detailed below.

**Figure 8: The Matrix Multiply applies attention to the values of V**

\(V\) is formed by a matrix multiply of the input embedding with the weight matrix \(W^v\), whose values were set by backpropagation. The row entries of \(V\) are then related to the corresponding input embedding. The matrix multiply weights the first column of V, representing a section of each of the input embeddings, with the first row of Dot, representing the similarity of \(w_0\) to each word of the input, and deposits the value in \(Z\).
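A tiny worked example may make this concrete (plain Python, toy values with illustrative `_toy` names): each output row of Z is a weighted sum of the rows of V, with the weights taken from the corresponding row of Dot.

```python
# Dot (2x2): row i holds the attention weights of input i over all inputs.
dot_toy = [[0.75, 0.25],
           [0.50, 0.50]]
# V (2x2): row j is the value vector for input j.
v_toy = [[8.0, 0.0],
         [0.0, 8.0]]

# Z = Dot @ V: output row i mixes the rows of V using row i of Dot.
z_toy = [[sum(dot_toy[i][j] * v_toy[j][c] for j in range(2)) for c in range(2)]
         for i in range(2)]
print(z_toy)  # [[6.0, 2.0], [4.0, 4.0]]
```

Row 0 of Z leans toward the first value vector because row 0 of Dot puts most of its weight there.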

### 3.3 our_simple_attend

In this section we’ll work on an implementation of *attend* whose operations you can see in figure 3. It is a slightly simplified version of the routine in efficient_attention.py. We will fill in a few lines of code. The main goal is to become familiar with the routine.

```
def our_simple_attend(
    q,
    k=None,
    v=None,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
    verbose=False,
    passthrough=False,
):
    """Dot-product attention, with masking, without optional chunking.

    Args:
      q: Query vectors, shape [q_len, d_qk]
      k: Key vectors, shape [kv_len, d_qk]; or None
      v: Value vectors, shape [kv_len, d_v]
      mask_fn: a function reference that implements masking (e.g. mask_self_attention)
      q_info: Query-associated metadata for masking
      kv_info: Key-associated metadata for masking
      dropout: Dropout rate
      rng: RNG for dropout

    Returns:
      A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
      dots_logsumexp has shape [q_len]. The logsumexp of the attention
      probabilities is useful for combining multiple rounds of attention (as in
      LSH attention).
    """
    assert v is not None
    share_qk = k is None
    if share_qk:
        k = q
        if kv_info is None:
            kv_info = q_info

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Dot-product attention.
    kr = np.swapaxes(k, -1, -2)  # note the fancy transpose for later..

    ## Step 1 ##
    dots = np.matmul(q, kr)
    if verbose:
        print("Our attend dots", dots.shape)

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    # dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)  # original
    # dots = np.exp(dots - dots_logsumexp)  # original
    ## Step 2 ##
    # replace with our_softmax()
    dots, dots_logsumexp = our_softmax(dots, passthrough=passthrough)
    if verbose:
        print("Our attend dots post softmax", dots.shape, dots_logsumexp.shape)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    ## Step 3 ##
    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    if verbose:
        print("Our attend out1", out.shape)
    out = np.reshape(out, (-1, out.shape[-1]))
    if verbose:
        print("Our attend out2", out.shape)
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp
```

```
seq_len = 8
emb_len = 5
d_qk = 3
d_v = 4
with fastmath.use_backend("jax"):  # specify the backend for consistency
    rng_attend = fastmath.random.get_prng(1)
    q = k = jax.random.uniform(rng_attend, (seq_len, d_qk), dtype=np.float32)
    v = jax.random.uniform(rng_attend, (seq_len, d_v), dtype=np.float32)
    o, logits = our_simple_attend(
        q,
        k,
        v,
        mask_fn=None,
        q_info=None,
        kv_info=None,
        dropout=0.0,
        rng=rng_attend,
        verbose=True,
    )
print(o, "\n", logits)
```

```
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
[[0.5606322 0.7290603 0.52512413 0.47101063]
[0.5713517 0.71991956 0.5033342 0.46975708]
[0.5622886 0.7288458 0.52172124 0.46318397]
[0.55683166 0.72234154 0.542236 0.46997216]
[0.56504494 0.72274375 0.5204978 0.47231334]
[0.56175965 0.7216782 0.53293145 0.48003793]
[0.56753993 0.72232544 0.5141734 0.46625748]
[0.57100445 0.70785505 0.5325362 0.4590797 ]]
[2.6512177 2.1914332 2.6630518 2.7792363 2.4583826 2.5421977 2.4145055
2.5111294]
```

## 4 Class OurSelfAttention

Here we create our own self-attention layer with a class `OurSelfAttention`. The parent class will be the tl.SelfAttention layer in Trax. We will only override the `forward_unbatched` routine.

```
class OurSelfAttention(tl.SelfAttention):
    """Our self-attention. Just the forward function."""

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        print("ourSelfAttention:forward_unbatched")
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)
        if self._bias:
            if self._share_qk:
                w_q, w_v, w_o, b_q, b_v = weights
            else:
                w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        else:
            if self._share_qk:
                w_q, w_v, w_o = weights
            else:
                w_q, w_k, w_v, w_o = weights

        print("x.shape,w_q.shape", x.shape, w_q.shape)
        q = np.matmul(x, w_q)
        k = None
        if not self._share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self._bias:
            q = q + b_q
            if not self._share_qk:
                k = k + b_k
            v = v + b_v

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self._causal,
            exclude_self=self._share_qk,
            masked=self._masked,
        )
        q_info = kv_info = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))

        assert (mask is not None) == self._masked
        if self._masked:
            # mask is a boolean array (True means "is valid token")
            ones_like_mask = tie_in(x, np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

        # Notice, we are calling our version of attend
        o, _ = our_simple_attend(
            q,
            k,
            v,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self._attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Notice, the w_o weight matrix is applied to the output of attend in forward_unbatched
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self._output_dropout, output_rng)
        return out, state
```

```
causal = False
masked = False
mask = None
attention_dropout = 0.0
n_heads = 3
d_qk = 3
d_v = 4
seq_len = 8
emb_len = 5
batch_size = 1

osa = OurSelfAttention(
    n_heads=n_heads,
    d_qk=d_qk,
    d_v=d_v,
    causal=causal,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

rng_osa = fastmath.random.get_prng(1)
x = jax.random.uniform(
    jax.random.PRNGKey(0), (batch_size, seq_len, emb_len), dtype=np.float32
)
_, _ = osa.init(tl.shapes.signature(x), rng=rng_osa)
```

```
osa(x)
```

```
ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
ourSelfAttention:forward_unbatched
x.shape,w_q.shape (8, 5) (5, 3)
Our attend dots (8, 8)
Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
```

```
DeviceArray([[[ 6.70414209e-01, -1.04319841e-01, -5.33822298e-01,
1.92711830e-01, -4.54187393e-05],
[ 6.64090097e-01, -1.01875424e-01, -5.35733163e-01,
1.88311756e-01, -6.30629063e-03],
[ 6.73380017e-01, -1.06952369e-01, -5.31989932e-01,
1.90056756e-01, 1.30271912e-03],
[ 6.84564888e-01, -1.13240272e-01, -5.50182462e-01,
1.95673436e-01, 5.47638535e-03],
[ 6.81435883e-01, -1.11068964e-01, -5.32343209e-01,
1.91912338e-01, 5.69400191e-03],
[ 6.80724978e-01, -1.08496904e-01, -5.34994125e-01,
1.96332246e-01, 5.89773059e-03],
[ 6.80933356e-01, -1.14087075e-01, -5.18659890e-01,
1.90674111e-01, 1.14096105e-02],
[ 6.80265009e-01, -1.09031796e-01, -5.38248718e-01,
1.94203183e-01, 4.23943996e-03]]], dtype=float32)
```

## 5 Trax LSHSelfAttention

### 5.1 Description

The larger the matrix multiply in the previous section is, the more context can be taken into account when making the next decision. However, the self attention dot product grows as the size of the input squared. For example, if one wished to have an input size of 1024, that would result in \(1024^2\) or over a million dot products for each head! As a result, there has been significant research related to reducing the compute requirements. One such approach is Locality Sensitive Hashing (LSH) Self Attention.

We previously utilized LSH to find similar tweets without resorting to calculating cosine similarity for each pair of embeddings. We will use a similar approach here. It may be best described with an example.

**Figure 9: Example of LSH Self Attention**

LSH Self attention uses Queries only, no Keys. Attention then generates a metric of the similarity of each value of Q relative to all the other values in Q. An earlier article demonstrated that values which hash to the same bucket are likely to be similar. Further, multiple random hashes can improve the chances of finding entries which are similar. This is the approach taken here, though the hash is implemented a bit differently. The values of Q are hashed into buckets using a randomly generated set of hash vectors. Multiple sets of hash vectors are used, generating multiple hash tables. In the figure above, we have 3 hash tables with 4 buckets in each table. Notionally, following the hash, the values of Q have been replicated 3 times and distributed to their appropriate bucket in each of the 3 tables. To find similarity then, one generates dot-products only between members of the buckets. The result of this operation provides information on which entries are similar. As the operation has been distributed over multiple hash tables, these results need to be combined to form a complete picture, and this can be used to generate a reduced dot-product attention array. It's clear that because we do not compare every value vs every other value, the size of *Dots* will be reduced.
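The bucket-assignment idea can be sketched in a few lines of plain Python (a toy stand-in, not the Trax code; `toy_lsh_bucket` and its inputs are illustrative names):

```python
import random

def toy_lsh_bucket(vec, hash_vectors):
    """Assign vec to the bucket whose (possibly negated) hash vector it aligns
    with best; appending the negations doubles the bucket count, mirroring the
    [rotated_vecs, -rotated_vecs] trick used in the hash function below."""
    scores = [sum(a * b for a, b in zip(vec, h)) for h in hash_vectors]
    scores += [-s for s in scores]  # negated rotations
    return max(range(len(scores)), key=lambda i: scores[i])

random.seed(0)
hash_vectors = [[random.gauss(0, 1) for _ in range(3)] for _ in range(2)]  # 4 buckets

q1 = [1.0, 0.5, 0.2]
q3 = [-1.0, -0.5, -0.2]  # exact opposite of q1
# An exactly opposite vector is guaranteed to land in a different bucket,
# since negating the vector flips every score.
print(toy_lsh_bucket(q1, hash_vectors), toy_lsh_bucket(q3, hash_vectors))
```

Nearby vectors tend to share a bucket; repeating with several independent sets of hash vectors (multiple tables) raises the odds that similar pairs meet in at least one.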

The challenge in this approach is getting it to operate efficiently. In earlier projects the buckets were lists of entries and had varying length. This operates poorly on a vector processing machine such as a GPU or TPU. Ideally, operations are done in large blocks with uniform sizes. While it is straightforward to implement the hash algorithm this way, it is challenging to manage buckets and variable-sized dot-products. This will be discussed further below. For now, we will examine and implement the hash function.

### 5.2 our_hash_vectors

*our_hash_vectors* is a reimplementation of Trax's *hash_vectors*. It takes in an array of vectors, hashes the entries, and returns an array assigning each input vector to one of `n_buckets` buckets in each hash table. Hashing is described as creating *random rotations*; see Practical and Optimal LSH for Angular Distance.

**Figure 10: Processing steps in our_hash_vectors**

Note: in the diagram, sizes relate to our expected input \(Q\), while our_hash_vectors is written assuming a generic input vector.

**Step 1** Create an array of random normal vectors which will be our hash vectors. Each vector will be hashed into a hash table and into `rot_size//2` buckets. We use `rot_size//2` to reduce computation: later in the routine we will form the negative rotations with a simple negation and concatenate to get a full `rot_size` number of rotations.

- use fastmath.random.normal to create an array of random vectors of shape `(vecs.shape[-1], n_hashes, rot_size//2)`

**Step 2** In this step we simply do the matrix multiply. `jax` has an accelerated version of einsum; for other backends we will utilize more conventional routines.

**Step 2x**

- 2a: `np.reshape` random_rotations into a 2-dimensional array (`[-1, n_hashes * (rot_size // 2)]`)
- 2b: `np.dot` vecs and random_rotations, forming our rotated_vecs
- 2c: back to 3 dimensions with `np.reshape` to `[-1, n_hashes, rot_size//2]`
- 2d: prepare for concatenating by swapping dimensions with np.transpose `(1, 0, 2)`

**Step 3** Here we concatenate our rotation vectors, getting a full rot_size number of buckets (note, n_buckets = rot_size).

- use `np.concatenate` with `[rotated_vecs, -rotated_vecs]`, `axis=-1`

**Step 4** **This is the exciting step!** You have no doubt been wondering how we will turn these vectors into bucket indexes. By performing `np.argmax` over the rotations for a given entry, you get the index of the best match! We will use this as a bucket index.

- `np.argmax(...).astype(np.int32)`; be sure to use the correct axis!

**Step 5** In this style of hashing, items which land in bucket 0 of hash table 0 are not necessarily similar to those landing in bucket 0 of hash table 1, so we keep them separate. We do this by offsetting the bucket numbers by `n_buckets` for each successive table.

- add buckets and offsets and reshape into a one-dimensional array. This will return a 1D array of size `n_hashes * vecs.shape[0]`.

```
def our_hash_vectors(vecs, rng, n_buckets, n_hashes, mask=None, verbose=False):
    """
    Args:
      vecs: tensor of at least 2 dimensions
      rng: random number generator
      n_buckets: number of buckets in each hash table
      n_hashes: the number of hash tables
      mask: None indicating no mask, or a 1D boolean array of length
        vecs.shape[0] containing the location of padding values
      verbose: controls prints for debug
    Returns:
      A vector of size n_hashes * vecs.shape[0] containing the buckets
      associated with each input vector per hash table.
    """
    # check for even, integer bucket sizes
    assert isinstance(n_buckets, int) and n_buckets % 2 == 0

    rng = fastmath.stop_gradient(tie_in(vecs, rng))
    rot_size = n_buckets

    ### Step 1 ###
    rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2)
    random_rotations = fastmath.random.normal(rng, rotations_shape).astype(np.float32)
    if verbose: print("random.rotations.shape", random_rotations.shape)

    ### Step 2 ###
    if fastmath.backend_name() == 'jax':
        rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)
        if verbose: print("using jax")
    else:
        # Step 2a
        random_rotations = np.reshape(random_rotations,
                                      [-1, n_hashes * (rot_size // 2)])
        if verbose: print("random_rotations reshaped", random_rotations.shape)
        # Step 2b
        rotated_vecs = np.dot(vecs, random_rotations)
        if verbose: print("rotated_vecs1", rotated_vecs.shape)
        # Step 2c
        rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size // 2])
        if verbose: print("rotated_vecs2", rotated_vecs.shape)
        # Step 2d
        rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2))
        if verbose: print("rotated_vecs3", rotated_vecs.shape)

    ### Step 3 ###
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    if verbose: print("rotated_vecs.shape", rotated_vecs.shape)

    ### Step 4 ###
    buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32)
    if verbose: print("buckets.shape", buckets.shape)
    if verbose: print("buckets", buckets)

    if mask is not None:
        n_buckets += 1  # Create an extra bucket for padding tokens only
        buckets = np.where(mask[None, :], buckets, n_buckets - 1)

    # buckets is now (n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = tie_in(buckets, np.arange(n_hashes, dtype=np.int32))
    offsets = np.reshape(offsets * n_buckets, (-1, 1))

    ### Step 5 ###
    buckets = np.reshape(buckets + offsets, (-1,))
    if verbose: print("buckets with offsets", buckets.shape, "\n", buckets)
    return buckets
```

```
# example code. Note for reference, the sizes in this example match the values in the diagram above.
ohv_q = np.ones((8, 5))  # (seq_len=8, n_q=5)
ohv_n_buckets = 4  # even number
ohv_n_hashes = 3

with fastmath.use_backend("tensorflow-numpy"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(
        ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None, verbose=True
    )
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * seq_len,)
# note the random number generators do not produce the same results with different backends
with fastmath.use_backend("jax"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None)
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * seq_len,)
```

```
random.rotations.shape (5, 3, 2)
random_rotations reshaped (5, 6)
rotated_vecs1 (8, 6)
rotated_vecs2 (8, 3, 2)
rotated_vecs3 (3, 8, 2)
rotated_vecs.shape (3, 8, 4)
buckets.shape (3, 8)
buckets tf.Tensor(
[[3 3 3 3 3 3 3 3]
[3 3 3 3 3 3 3 3]
[3 3 3 3 3 3 3 3]], shape=(3, 8), dtype=int32)
buckets with offsets (24,)
tf.Tensor([ 3 3 3 3 3 3 3 3 7 7 7 7 7 7 7 7 11 11 11 11 11 11 11 11], shape=(24,), dtype=int32)
ohv shape (24,)
ohv tf.Tensor([ 3 3 3 3 3 3 3 3 7 7 7 7 7 7 7 7 11 11 11 11 11 11 11 11], shape=(24,), dtype=int32)
ohv shape (24,)
ohv [ 3 3 3 3 3 3 3 3 5 5 5 5 5 5 5 5 11 11 11 11 11 11 11 11]
```

### 5.3 Sorting Buckets

Now that we have a hash function, we can work on sorting our buckets and performing our matrix operations. We’ll walk through this algorithm in small steps:

- sort_buckets - we’ll perform the sort
- softmax
- dotandv - do the matrix math to form the dot product and output

These routines will demonstrate a simplified version of the algorithm. We won’t address masking and variable bucket sizes but will consider how they would be handled.

**sort_buckets**

At this point, we have called the hash function and were returned the associated buckets. For example, if we started with `q[n_seq,n_q]`, with `n_hash = 2; n_buckets = 4; n_seq = 8`, we might be returned: `bucket = [0,1,2,3,0,1,2,3, 4,5,6,7,4,5,6,7]`. Note that it is `n_hash * n_seq` long and that the bucket values for each hash have been offset by `n_buckets` so the numbers do not overlap. Going forward, we are going to sort this array of buckets to group together members of the same (hash, bucket) pair.

**Step 1** Our goal is to sort \(q\) rather than the bucket list, so we will need to track the association of the buckets to their elements in \(q\).

- using `np.arange`, create `ticker`, just a sequence of numbers (0 … n_hashes * seqlen) associating members of \(q\) with their bucket.

**Step 2** We want to disambiguate elements that map to the same bucket. When a sorting routine encounters multiple entries with the same value, it can correctly choose any of them to go first, which makes testing ambiguous; this step prevents that.

- multiply all the buckets by `seqlen` and then add `ticker % seqlen`

**Step 3** Here we are! Ready to sort. This is the exciting part.

- utilize fastmath.sort_key_val to sort `buckets_and_t` and `ticker`.

**Step 4** We need to be able to undo the sort at the end to get things back into their correct locations.

- sort `sticker` and `ticker` to form the reverse map

**Step 5** Create our sorted q and sorted v.

- use np.take and `st` to grab the correct values in `q` for the sorted values, `sq`. Use `axis=0`.
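The sort / undo-sort bookkeeping in the steps above can be sketched in plain Python (toy sizes, `toy_`-prefixed illustrative names; this is not the Trax code itself):

```python
toy_seqlen, toy_n_hashes = 4, 2
toy_buckets = [2, 0, 2, 1, 4, 6, 5, 4]               # n_hashes * seqlen entries
toy_ticker = list(range(toy_n_hashes * toy_seqlen))  # Step 1
toy_keys = [toy_seqlen * b + (t % toy_seqlen)        # Step 2: disambiguated keys
            for b, t in zip(toy_buckets, toy_ticker)]

toy_sticker = sorted(range(len(toy_keys)), key=lambda i: toy_keys[i])      # Step 3
toy_undo = sorted(range(len(toy_sticker)), key=lambda i: toy_sticker[i])   # Step 4

toy_q = [[10.0 * i] for i in range(toy_seqlen)]  # one value per sequence position
toy_st = [s % toy_seqlen for s in toy_sticker]   # Step 5: positions back into q
toy_sq = [toy_q[i] for i in toy_st]              # sorted q

# undo_sort is the inverse permutation: applying it to the sorted keys
# restores the original order.
toy_sorted_keys = [toy_keys[i] for i in toy_sticker]
print([toy_sorted_keys[i] for i in toy_undo] == toy_keys)  # True
```

The `seqlen * bucket + position` key construction is what makes the sort deterministic even when several positions share a bucket.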

```
def sort_buckets(buckets, q, v, n_buckets, n_hashes, seqlen, verbose=True):
    """
    Args:
      buckets: tensor of bucket assignments, length n_hashes * seqlen
      n_buckets: number of buckets in each hash table
      n_hashes: the number of hash tables
    """
    if verbose: print("---sort_buckets--")

    ## Step 1
    ticker = np.arange(n_hashes * seqlen)
    if verbose: print("ticker", ticker.shape, ticker)

    ## Step 2
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    if verbose: print("buckets_and_t", buckets_and_t.shape, buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    ## Step 3
    sbuckets_and_t, sticker = fastmath.sort_key_val(
        buckets_and_t, ticker, dimension=-1)
    if verbose: print("sbuckets_and_t", sbuckets_and_t.shape, sbuckets_and_t)
    if verbose: print("sticker", sticker.shape, sticker)

    ## Step 4
    _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
    if verbose: print("undo_sort", undo_sort.shape, undo_sort)

    ## Step 5
    st = sticker % seqlen
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)
    return sq, sv, sticker, undo_sort
```

```
t_n_hashes = 2
t_n_buckets = 4
t_n_seq = t_seqlen = 8
t_n_q = 3
n_v = 5

t_q = (np.array([(j % t_n_buckets) for j in range(t_n_seq)]) * np.ones((t_n_q, 1))).T
t_v = np.ones((t_n_seq, n_v))
t_buckets = np.array(
    [
        (j % t_n_buckets) + t_n_buckets * i
        for i in range(t_n_hashes)
        for j in range(t_n_seq)
    ]
)
print("q\n", t_q)
print("t_buckets: ", t_buckets)

t_sq, t_sv, t_sticker, t_undo_sort = sort_buckets(
    t_buckets, t_q, t_v, t_n_buckets, t_n_hashes, t_seqlen, verbose=True
)

print("sq.shape", t_sq.shape, "sv.shape", t_sv.shape)
print("sq\n", t_sq)
```

```
q
[[0. 0. 0.]
[1. 1. 1.]
[2. 2. 2.]
[3. 3. 3.]
[0. 0. 0.]
[1. 1. 1.]
[2. 2. 2.]
[3. 3. 3.]]
t_buckets: [0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7]
---sort_buckets--
ticker (16,) [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
buckets_and_t (16,) [ 0 9 18 27 4 13 22 31 32 41 50 59 36 45 54 63]
sbuckets_and_t (16,) [ 0 4 9 13 18 22 27 31 32 36 41 45 50 54 59 63]
sticker (16,) [ 0 4 1 5 2 6 3 7 8 12 9 13 10 14 11 15]
undo_sort (16,) [ 0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15]
sq.shape (16, 3) sv.shape (16, 5)
sq
[[0. 0. 0.]
[0. 0. 0.]
[1. 1. 1.]
[1. 1. 1.]
[2. 2. 2.]
[2. 2. 2.]
[3. 3. 3.]
[3. 3. 3.]
[0. 0. 0.]
[0. 0. 0.]
[1. 1. 1.]
[1. 1. 1.]
[2. 2. 2.]
[2. 2. 2.]
[3. 3. 3.]
[3. 3. 3.]]
```

### 5.4 Chunked dot product attention

Now let’s create the dot product attention. We have sorted \(Q\) so that elements that the hash has determined are likely to be similar are adjacent to each other. We now want to perform the dot-product within those limited regions - in ‘chunks’.

**Figure 11: Performing dot product in ‘chunks’**

The example we have been working on is shown above, with sequences of 8, 2 hashes, 4 buckets and, conveniently, the content of Q was such that when sorted, there were 2 entries in each bucket. If we reshape Q into a (8,2,n_q), we can use numpy matmul to perform the operation. Numpy matmul will treat the inputs as a stack of matrices residing in the last two indexes. This will allow us to matrix multiply Q with itself in *chunks* and later can also be used to perform the matrix multiply with v.
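A minimal plain-Python stand-in for this stacked multiply may help (illustrative names; numpy's matmul does the same thing natively when given 3-D arrays):

```python
def chunked_qqt(chunks):
    """chunks: list of chunks, each a list of row vectors.
    Returns per-chunk q @ q^T, i.e. shape (n_chunks, chunk_len, chunk_len)."""
    out = []
    for chunk in chunks:  # each chunk is multiplied independently
        n = len(chunk)
        out.append([[sum(a * b for a, b in zip(chunk[i], chunk[j]))
                     for j in range(n)] for i in range(n)])
    return out

# two chunks of 2 vectors each (chunk_len=2, d=2)
chunks = [[[1.0, 0.0], [0.0, 1.0]],
          [[1.0, 1.0], [1.0, 1.0]]]
print(chunked_qqt(chunks))  # [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
```

Each chunk yields its own small (chunk_len, chunk_len) similarity matrix, which is exactly the reduced *Dots* structure LSH attention produces.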

We will perform a softmax on the output of the dot product of Q and Q, but in this case, there is a bit more to the story. Recall the output of the hash had multiple hash tables. We will perform softmax on those separately and then must combine them. This is where the form of softmax we defined at the top of the notebook comes into play. The routines below will utilize the `logsumexp` values that the `our_softmax` routine calculates.

There is a good deal of reshaping to get things into the right formats. The code has many `print` statements that match the expected values below. You can use those to check your work as you go along. If you don’t do a lot of 3-dimensional matrix multiplications in your daily life, it might be worthwhile to open a spare cell and practice a few simple examples to get the hang of it! Here is one to start with:

```
a = np.arange(16 * 3).reshape((16, 3))
chunksize = 2
ar = np.reshape(
    a, (-1, chunksize, a.shape[-1])
)  # the -1 usage is very handy, see numpy reshape
print(ar.shape)
```

`(8, 2, 3)`

**Step 1** Reshape Q.

- np.reshape `sq` (sorted q) to be 3 dimensions. The middle dimension is the size of the 'chunk' specified by `kv_chunk_len`.
- np.swapaxes to perform a 'transpose' on the reshaped `sq`, *but only on the last two dimensions*.
- np.matmul the two values.

**Step 2**

- Use `our_softmax` to perform the softmax on the dot product. Don't forget `passthrough`.

**Step 3**

- np.reshape `sv`. Like `sq`, the middle dimension is the size of the 'chunk' specified by `kv_chunk_len`.
- np.matmul `dotlike` and the reshaped `sv`.
- np.reshape `so` to a two-dimensional array whose last dimension stays the same (`so.shape[-1]`).
- `logits` also needs reshaping; we'll do that.

**Step 4** Now we can undo the sort.

- Use np.take with `undo_sort` and `axis=0` to unsort `so`.
- Do the same with `slogits`.

**Step 5** This step combines the results of multiple hashes. Recall, the softmax was only over the values in one hash; this extends it to all the hashes. Read through it, the code is provided. Note that this takes place *after* the matrix multiply with v, while the softmax output is used before that multiply.
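Step 5 can be illustrated in isolation. Below is a standalone sketch in plain NumPy (a hand-rolled `logsumexp` stands in for `fastmath.logsumexp`; the shapes and values are made up for illustration) of weighting two hash rounds by their share of the total softmax mass:

```python
import numpy as np

def logsumexp(a, axis=None, keepdims=False):
    # minimal stand-in for fastmath.logsumexp
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)

n_hashes, seqlen, d_v = 2, 4, 3
# per-hash outputs (already multiplied by v) and their softmax normalizers
o = np.stack([np.full((seqlen, d_v), 1.0), np.full((seqlen, d_v), 3.0)])
logits = np.stack([np.zeros((seqlen, 1)), np.log(3.0) * np.ones((seqlen, 1))])

# each hash is weighted by its share of the total softmax mass: 1/4 and 3/4
probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
combined = np.sum(o * probs, axis=0)
print(combined[0])  # 0.25 * 1 + 0.75 * 3 = 2.5 -> [2.5 2.5 2.5]
```

The weights `probs` sum to 1 across the hash axis, so the combination is a proper weighted average of the per-hash outputs.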

```
def dotandv(sq, sv, undo_sort, kv_chunk_len, n_hashes, seqlen, passthrough, verbose=False):
    # Step 1
    rsq = np.reshape(sq, (-1, kv_chunk_len, sq.shape[-1]))
    rsqt = np.swapaxes(rsq, -1, -2)
    if verbose: print("rsq.shape,rsqt.shape: ", rsq.shape, rsqt.shape)
    dotlike = np.matmul(rsq, rsqt)
    if verbose: print("dotlike\n", dotlike)

    # Step 2
    dotlike, slogits = our_softmax(dotlike, passthrough)
    if verbose: print("dotlike post softmax\n", dotlike)

    # Step 3
    vr = np.reshape(sv, (-1, kv_chunk_len, sv.shape[-1]))
    if verbose: print("dotlike.shape, vr.shape:", dotlike.shape, vr.shape)
    so = np.matmul(dotlike, vr)
    if verbose: print("so.shape:", so.shape)
    so = np.reshape(so, (-1, so.shape[-1]))
    slogits = np.reshape(slogits, (-1,))  # provided
    if verbose: print("so.shape,slogits.shape", so.shape, slogits.shape)

    # Step 4
    o = np.take(so, undo_sort, axis=0)
    logits = np.take(slogits, undo_sort, axis=0)
    if verbose: print("o.shape,o", o.shape, o)
    if verbose: print("logits.shape, logits", logits.shape, logits)

    # Step 5 (provided): combine the results of the separate hashes
    if n_hashes > 1:
        o = np.reshape(o, (n_hashes, seqlen, o.shape[-1]))
        logits = np.reshape(logits, (n_hashes, seqlen, 1))
        probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
        o = np.sum(o * probs, axis=0)

    return o
```

```
t_kv_chunk_len = 2
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=True,
    verbose=True,
)
print("out\n", out)
print("\n-----With softmax enabled----\n")
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=False,
    verbose=True,
)
print("out\n", out)
```

```
rsq.shape,rsqt.shape: (8, 2, 3) (8, 3, 2)
dotlike
[[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]]
dotlike post softmax
[[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]]
dotlike.shape, vr.shape: (8, 2, 2) (8, 2, 5)
so.shape: (8, 2, 5)
so.shape,slogits.shape (16, 5) (16,)
o.shape,o (16, 5) [[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]
[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]
[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]
[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]]
logits.shape, logits (16,) [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
out
[[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]
[ 0. 0. 0. 0. 0.]
[ 6. 6. 6. 6. 6.]
[24. 24. 24. 24. 24.]
[54. 54. 54. 54. 54.]]
-----With softmax enabled----
rsq.shape,rsqt.shape: (8, 2, 3) (8, 3, 2)
dotlike
[[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]
[[ 0. 0.]
[ 0. 0.]]
[[ 3. 3.]
[ 3. 3.]]
[[12. 12.]
[12. 12.]]
[[27. 27.]
[27. 27.]]]
dotlike post softmax
[[[0.5 0.5 ]
[0.5 0.5 ]]
[[0.5 0.5 ]
[0.5 0.5 ]]
[[0.49999976 0.49999976]
[0.49999976 0.49999976]]
[[0.49999976 0.49999976]
[0.49999976 0.49999976]]
[[0.5 0.5 ]
[0.5 0.5 ]]
[[0.5 0.5 ]
[0.5 0.5 ]]
[[0.49999976 0.49999976]
[0.49999976 0.49999976]]
[[0.49999976 0.49999976]
[0.49999976 0.49999976]]]
dotlike.shape, vr.shape: (8, 2, 2) (8, 2, 5)
so.shape: (8, 2, 5)
so.shape,slogits.shape (16, 5) (16,)
o.shape,o (16, 5) [[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]
[0.9999995 0.9999995 0.9999995 0.9999995 0.9999995]]
logits.shape, logits (16,) [ 0.6931472 3.6931472 12.693148 27.693148 0.6931472 3.6931472
12.693148 27.693148 0.6931472 3.6931472 12.693148 27.693148
0.6931472 3.6931472 12.693148 27.693148 ]
out
[[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.99999905 0.99999905 0.99999905 0.99999905 0.99999905]
[0.99999905 0.99999905 0.99999905 0.99999905 0.99999905]
[1. 1. 1. 1. 1. ]
[1. 1. 1. 1. 1. ]
[0.99999905 0.99999905 0.99999905 0.99999905 0.99999905]
[0.99999905 0.99999905 0.99999905 0.99999905 0.99999905]]
```

We have now written example code for most of the operations that are unique to the LSH version of self-attention. I'm sure at this point you are wondering what happens when the number of entries per bucket is not as evenly distributed as in our example. It is possible, for example, for all of the `seqlen` entries to land in one bucket. Further, since the buckets are not aligned, our 'chunks' may be misaligned with the start of a bucket. The implementation addresses this by attending to adjacent chunks, as was described in the lecture:

**Figure 12: Misaligned Access, looking before and after**

Hopefully, having implemented parts of this, you will appreciate this diagram more fully.
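The "look before/after" idea itself can be sketched compactly. Below is a standalone NumPy sketch; the function name and the wrap-around behavior are illustrative assumptions, not trax's exact `look_adjacent`:

```python
import numpy as np

def look_adjacent_sketch(x, n_before, n_after):
    # Hypothetical stand-in for trax's look_adjacent: let each chunk also
    # see n_before preceding and n_after following chunks (wrapping around)
    # by concatenating them along the chunk-length axis.
    shifts = range(n_before, -n_after - 1, -1)
    slices = [np.roll(x, s, axis=0) for s in shifts]
    return np.concatenate(slices, axis=1)

chunks = np.arange(4 * 2 * 1).reshape(4, 2, 1)  # 4 chunks of length 2
wide = look_adjacent_sketch(chunks, n_before=1, n_after=0)
print(wide.shape)  # (4, 4, 1): each chunk now also sees its predecessor
```

Because keys and values are widened this way while queries are not, a bucket that straddles a chunk boundary can still attend to all of its own members.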

### 5.5 OurLSHSelfAttention

We can examine the full implementations below. Areas we did not 'attend to' in our implementations above include variable bucket sizes and masking. We will instantiate a layer of the full implementation below. We tried to use the same variable names above to make it easier to decipher the full version. Note that some of the functionality we implemented in our routines is split between `attend` and `forward_unbatched`. We've inserted our version of the hash below, but use the original version of `attend`.

```
# original version from trax 1.3.4
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

    Args:
      q: Query vectors, shape [q_len, d_qk]
      k: Key vectors, shape [kv_len, d_qk]; or None
      v: Value vectors, shape [kv_len, d_v]
      q_chunk_len: Set to non-zero to enable chunking for query vectors
      kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
      n_chunks_before: Number of adjacent previous chunks to attend to
      n_chunks_after: Number of adjacent subsequent chunks to attend to
      mask_fn: TODO(kitaev) doc
      q_info: Query-associated metadata for masking
      kv_info: Key-associated metadata for masking
      dropout: Dropout rate
      rng: RNG for dropout

    Returns:
      A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
      dots_logsumexp has shape [q_len]. The logsumexp of the attention
      probabilities is useful for combining multiple rounds of attention (as in
      LSH attention).
    """
    assert v is not None
    share_qk = k is None

    if q_info is None:
        q_info = np.arange(q.shape[-2], dtype=np.int32)

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2], dtype=np.int32)

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        if kv_info is None:
            kv_info = q_info
        elif kv_chunk_len is not None:
            # kv_info is not None, but reshape as required.
            kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp
```
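One detail worth noting in `attend` is that dropout is broadcast across the chunk dimension: the keep mask is drawn only for the last two axes, so the same entries are dropped in every chunk. A minimal NumPy sketch of that broadcast (NumPy's generator replaces `fastmath.random.bernoulli`, and `tie_in`, which is only an XLA dependency hint, is omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
dropout = 0.25
dots = np.ones((8, 2, 2))  # (n_chunks, q_chunk_len, kv_chunk_len)

# draw one keep-mask for the last two axes only...
keep_prob = 1.0 - dropout
keep = rng.random((dots.shape[-2], dots.shape[-1])) < keep_prob
# ...and let broadcasting apply it identically across all 8 chunks,
# rescaling survivors by 1/keep_prob (inverted dropout)
dots = dots * keep.astype(dots.dtype) / keep_prob
print(dots.shape)  # (8, 2, 2)
```

Sharing the mask across chunks saves memory and randomness while still regularizing the attention weights.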

```
class OurLSHSelfAttention(tl.LSHSelfAttention):
    """Our simplified LSH self-attention."""

    def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state):
        attend_rng, output_rng = fastmath.random.split(rng)
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_hash_rng = state
            hash_rng, hash_subrng = fastmath.random.split(old_hash_rng)
            # buckets = self.hash_vectors(q, hash_subrng, mask)  # original
            ## use our version of hash
            buckets = our_hash_vectors(
                q, hash_subrng, self._n_buckets, self._n_hashes, mask=mask
            )
            s_buckets = buckets
            if self._max_length_for_buckets:
                length = self._n_hashes * self._max_length_for_buckets
                if buckets.shape[0] < length:
                    s_buckets = np.concatenate(
                        [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)],
                        axis=0,
                    )
            state = (s_buckets, hash_rng)
        else:
            buckets, _ = state
            if self._max_length_for_buckets:
                buckets = buckets[: self._n_hashes * x.shape[0]]

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self._n_hashes * seqlen

        ticker = tie_in(x, np.arange(self._n_hashes * seqlen, dtype=np.int32))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = fastmath.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = fastmath.sort_key_val(
            buckets_and_t, ticker, dimension=-1
        )
        _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t)
        sticker = fastmath.stop_gradient(sticker)
        undo_sort = fastmath.stop_gradient(undo_sort)

        st = sticker % seqlen
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self._causal,
            exclude_self=True,
            masked=self._masked,
        )
        q_info = st

        assert (mask is not None) == self._masked
        kv_info = None
        if self._masked:
            # mask is a boolean array (True means "is valid token")
            smask = np.take(mask, st, axis=0)
            ones_like_mask = tie_in(x, np.ones_like(smask, dtype=np.int32))
            kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask)

        ## use original version of attend (ours lacks masks and masking)
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self._chunk_len,
            n_chunks_before=self._n_chunks_before,
            n_chunks_after=self._n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self._attention_dropout,
            rng=attend_rng,
        )

        # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
        # also work, but these helpers include performance optimizations for TPU.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self._n_hashes > 1:
            o = np.reshape(o, (self._n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self._n_hashes, seqlen, 1))
            probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self._output_dropout, output_rng)
        return out, state
```

```
# Here we're going to try out our LSHSelfAttention
n_heads = 3
causal = False
masked = False
mask = None
chunk_len = 8
n_chunks_before = 0
n_chunks_after = 0
attention_dropout = 0.0
n_hashes = 5
n_buckets = 4
seq_len = 8
emb_len = 5

al = OurLSHSelfAttention(
    n_heads=n_heads,
    d_qk=3,
    d_v=4,
    causal=causal,
    chunk_len=8,
    n_chunks_before=n_chunks_before,
    n_chunks_after=n_chunks_after,
    n_hashes=n_hashes,
    n_buckets=n_buckets,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

x = jax.random.uniform(jax.random.PRNGKey(0), (1, seq_len, emb_len), dtype=np.float32)
al_osa = fastmath.random.get_prng(1)
_, _ = al.init(tl.shapes.signature(x), rng=al_osa)
```

`al(x)`

```
DeviceArray([[[ 6.6842824e-01, -1.1364317e-01, -5.4430604e-01,
2.1126242e-01, -1.0988623e-02],
[ 7.0949769e-01, -1.5455186e-01, -5.9923327e-01,
2.2719446e-01, 1.3833597e-02],
[ 7.1442676e-01, -1.2046637e-01, -5.3956550e-01,
1.7320302e-01, -1.6552359e-02],
[ 6.7178923e-01, -7.6611102e-02, -5.9399861e-01,
2.1236290e-01, 7.9482794e-04],
[ 7.1518433e-01, -1.1359167e-01, -5.7821894e-01,
2.1304408e-01, 3.0598283e-02],
[ 6.8235350e-01, -9.3979925e-02, -5.5341840e-01,
2.1608174e-01, -6.6673756e-04],
[ 6.1286640e-01, -8.1027031e-02, -4.8148823e-01,
1.9373316e-01, 3.1555220e-02],
[ 7.2203499e-01, -1.0199663e-01, -5.5215168e-01,
1.7872261e-01, -2.2289157e-02]]], dtype=float32)
```

## 6 Acknowledgements

I’d like to express my thanks to the great Natural Language Processing with Attention Models course, which I completed, and to acknowledge the use of some images and other materials from the course in this article.