Building an LSTM Language Model from scratch

In this article we will look at how to build an LSTM language model from scratch that can predict the next word in a sequence of words, covering all the details of how to build the AWD-LSTM architecture.
deep-learning
natural-language-processing
Author

Pranath Fernando

Published

May 31, 2021

1 Introduction

In this article we will look at how to build an LSTM language model that can predict the next word in a sequence of words. As part of this, we will also explore several regularization methods. We will build a range of models using basic Python & PyTorch to illustrate the fundamentals of this type of model, while also using aspects of the fastai library. By the end we will have covered all the different components that make up the AWD-LSTM model architecture.

This work is based on material from the fastai deep learning book, chapter 12.

2 Dataset

We will use the fastai curated Human Numbers dataset for this exercise. This is a dataset of the first 10,000 numbers written out as words in English.

from fastai.text.all import *  # fastai imports used throughout (also brings in torch, nn and F)

path = untar_data(URLs.HUMAN_NUMBERS)
Path.BASE_PATH = path
path.ls()
(#2) [Path('valid.txt'),Path('train.txt')]
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
lines
(#9998) ['one \n','two \n','three \n','four \n','five \n','six \n','seven \n','eight \n','nine \n','ten \n'...]
text = ' . '.join([l.strip() for l in lines])
text[:100]
'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'
tokens = text.split(' ')
tokens[:10]
['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']
vocab = L(*tokens).unique()
vocab
(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)
nums
(#63095) [0,1,2,1,3,1,4,1,5,1...]
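
As a quick sanity check (an illustrative snippet using the variables above, not from the original notebook), we can map the numericalised tokens back into words:

' '.join(vocab[o] for o in nums[:10])
# 'one . two . three . four . five .'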

3 Language Model 1 - Linear Neural Network

Let's first try a simple linear model that will aim to predict each word based on the previous 3 words. To do this we can create our input variable as every sequence of 3 words, and our output/target variable as the next word after each sequence of 3.

So, in Python, we could construct these variables in the following way: first as tokens (to see what they look like), and then as PyTorch tensors of the numeric values.

L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))
(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]
seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))
seqs
(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10,  1, 11]), 1),(tensor([ 1, 12,  1]), 13),(tensor([13,  1, 14]), 1),(tensor([ 1, 15,  1]), 16)...]

We can group these into batches using the fastai DataLoaders class.

bs = 64
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=bs, shuffle=False)

So we will create a linear neural network with 3 layers, and a couple of specific features.

The first feature is to do with how the embeddings are used. The first layer will use the first word's embedding; the second layer will use the second word's embedding plus the first layer's activations; and the third layer will use the third word's embedding plus the second layer's activations. The key observation is that each word is interpreted in the context of the words that came before it.

The second feature is that each of these 3 layers will actually be the same layer, i.e. they will share a single weight matrix. Although each application of the layer sees a different word, the job being done is the same each time, so one layer should be reusable for all 3 positions. In other words, while the activation values will change as words move through the network, the layer weights will not change from layer to layer.

This way, the layer doesn't just learn to handle one position (e.g. the second word position); it's forced to generalise and learn to handle all 3 word positions.

class LMModel1(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  # input to hidden: word embeddings
        self.h_h = nn.Linear(n_hidden, n_hidden)     # hidden to hidden
        self.h_o = nn.Linear(n_hidden, vocab_sz)     # hidden to output: next-word scores
        
    def forward(self, x):
        h = F.relu(self.h_h(self.i_h(x[:,0])))
        h = h + self.i_h(x[:,1])
        h = F.relu(self.h_h(h))
        h = h + self.i_h(x[:,2])
        h = F.relu(self.h_h(h))
        return self.h_o(h)

So we have 3 key layers:

  • An embedding layer, to convert each word index into activations (i_h)
  • A linear layer, to create the activations for the hidden state (h_h)
  • A final linear layer, to predict the target 4th word (h_o)

Let's try training a model built with this architecture.

learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, 
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)
epoch train_loss valid_loss accuracy time
0 1.824297 1.970941 0.467554 00:01
1 1.386973 1.823242 0.467554 00:01
2 1.417556 1.654497 0.494414 00:01
3 1.376440 1.650849 0.494414 00:01

So how might we establish a baseline to judge these results? What if we defined a naive predictor that simply always predicted the most common word? Let's find the most common word in the validation set, and then calculate the accuracy we would get by always predicting it.

n,counts = 0,torch.zeros(len(vocab))
for x,y in dls.valid:
    n += y.shape[0]
    for i in range_of(vocab): counts[i] += (y==i).long().sum()
idx = torch.argmax(counts)
idx, vocab[idx.item()], counts[idx].item()/n
(tensor(29), 'thousand', 0.15165200855716662)

So always predicting the most common token, 'thousand', would give an accuracy of about 15%; our simple model, at roughly 49%, is already doing much better than this naive baseline.

4 Language Model 2 - Recurrent Neural Network

So in the forward() method, rather than repeating the same lines for each word, we could convert this into a for loop, which not only makes our code simpler but also lets us extend it to sequences of other lengths.

class LMModel2(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        
    def forward(self, x):
        h = 0
        for i in range(3):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
        return self.h_o(h)
learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, 
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)
epoch train_loss valid_loss accuracy time
0 1.816274 1.964143 0.460185 00:01
1 1.423805 1.739964 0.473259 00:01
2 1.430327 1.685172 0.485382 00:01
3 1.388390 1.657033 0.470406 00:01

Note that each time we go through the loop, the resulting activations are passed along to the next iteration via the h variable, which is called the hidden state. A recurrent neural network is simply a network that is defined using a loop like this.

5 Language Model 3 - A better RNN

Notice that in the latest model we initialise the hidden state to zero on every forward pass, i.e. for every batch, which means we throw away everything the model knows about the preceding text and limit how much information can be carried over. Also, is there a way to get more 'signal'? Rather than predicting just the 4th word, we could for example try to predict the next word after every word.

To avoid losing our hidden state so frequently and to carry over more useful information, we could initialise it once, outside the forward method. However, this effectively makes our model as deep as the entire sequence of tokens, i.e. 10,000 tokens would give a 10,000-layer network, which means that calculating all the gradients back to the first word/layer could be very slow and memory-hungry.

So rather than calculating all of those gradients, we can keep just the last 3 layers' worth. To drop the earlier gradient history in PyTorch we use the detach() method; this approach is known as truncated backpropagation through time.
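
As a minimal standalone sketch of what detach() does (not part of the model code), it returns a tensor with the same values but with the gradient history cut off:

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
z = y.detach()           # same value as y, but cut off from the computation graph
y.requires_grad, z.requires_grad
# (True, False) - gradients can still flow back to x through y, but not through z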

Because this version of the model carries its activations over between calls to forward(), we could call this kind of model stateful.

class LMModel3(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        self.h = 0
        
    def forward(self, x):
        for i in range(3):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)
        self.h = self.h.detach()
        return out
    
    def reset(self): self.h = 0

To use this model we need to make sure the samples are seen in the right order: because the hidden state is carried over from one batch to the next, sample i of each batch needs to be the continuation of sample i of the previous batch. To arrange this we divide the dataset into bs (here 64) equally sized contiguous streams and interleave them, with each sample still being a sequence of 3 tokens.

m = len(seqs)//bs
m,bs,len(seqs)
(328, 64, 21031)
def group_chunks(ds, bs):
    m = len(ds) // bs
    new_ds = L()
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds
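
As a toy illustration (not part of the training pipeline), with 8 items and a batch size of 2, group_chunks interleaves the two halves of the data so that item i of each successive batch continues the same stream:

group_chunks(L(range(8)), 2)
# (#8) [0,4,1,5,2,6,3,7] - the batches of size 2 are then (0,4), (1,5), (2,6), (3,7)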

cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(
    group_chunks(seqs[:cut], bs), 
    group_chunks(seqs[cut:], bs), 
    bs=bs, drop_last=True, shuffle=False)

batch = dls.one_batch()
batch[0].size()
torch.Size([64, 3])
learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(10, 3e-3)
epoch train_loss valid_loss accuracy time
0 1.708583 1.873094 0.401202 00:01
1 1.264271 1.781330 0.433173 00:01
2 1.087642 1.535732 0.521875 00:01
3 1.007973 1.578549 0.542308 00:01
4 0.945740 1.660635 0.569231 00:01
5 0.902835 1.605541 0.551923 00:01
6 0.878297 1.527385 0.579087 00:01
7 0.814197 1.451913 0.606250 00:01
8 0.783523 1.509463 0.604087 00:01
9 0.763500 1.511033 0.608413 00:01

6 Language Model 4 - Creating more signal

So with the current model we still predict just one word for every 3 input words, which limits the amount of signal - what if we predicted the next word after every word?

To do this we need to restructure our data so that the target contains the next word after each of the input words. Rather than hard-coding a length of 3, we introduce a sequence length variable sl, which we set here to 16.

sl = 16
seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
         for i in range(0,len(nums)-sl-1,sl))
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
                             group_chunks(seqs[cut:], bs),
                             bs=bs, drop_last=True, shuffle=False)

batch = dls.one_batch()
batch[0].size()
torch.Size([64, 16])
[L(vocab[o] for o in s) for s in seqs[0]]
[(#16) ['one','.','two','.','three','.','four','.','five','.'...],
 (#16) ['.','two','.','three','.','four','.','five','.','six'...]]

Now we can refactor our model to predict the next word after each word rather than after each 3 word sequence.

class LMModel4(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        self.h = 0
        
    def forward(self, x):
        outs = []
        for i in range(sl):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
            outs.append(self.h_o(self.h))
        self.h = self.h.detach()
        return torch.stack(outs, dim=1)
    
    def reset(self): self.h = 0

# The model output has shape [bs, sl, vocab_sz] and the targets [bs, sl];
# F.cross_entropy expects [N, vocab_sz] and [N], so we flatten both first
def loss_func(inp, targ):
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))
learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)
epoch train_loss valid_loss accuracy time
0 3.226453 3.039626 0.200033 00:00
1 2.295425 1.925965 0.439697 00:00
2 1.743091 1.818798 0.423258 00:00
3 1.471100 1.779967 0.467285 00:00
4 1.267640 1.823129 0.504883 00:00
5 1.100705 1.991244 0.500814 00:00
6 0.960767 2.086404 0.545085 00:00
7 0.857365 2.240561 0.556803 00:00
8 0.776844 2.004017 0.568766 00:00
9 0.711604 1.991193 0.588949 00:00
10 0.659614 2.064157 0.585775 00:00
11 0.619464 2.033359 0.606283 00:00
12 0.587681 2.100323 0.614176 00:00
13 0.565472 2.145048 0.603760 00:00
14 0.553879 2.149167 0.605550 00:00

Because the task is now harder (predicting the next word at every position) we need to train for longer, but we still do well. Since this is effectively a very deep network, the results can vary from run to run, because the gradients in such a deep network can vary hugely.

7 Language Model 5 - Multi-layer RNN

While in a sense we already have a multi-layer network, the part we repeat at each step is still just a single layer. A deeper RNN, with several stacked layers, gives us more computational power at each step.

We can use PyTorch's nn.RNN class to replace our hand-written loop; it also lets us stack multiple RNN layers rather than the single one we had before.

class LMModel5(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(res)
    
    def reset(self): self.h.zero_()
learn = Learner(dls, LMModel5(len(vocab), 64, 2), 
                loss_func=CrossEntropyLossFlat(), 
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 3e-3)
epoch train_loss valid_loss accuracy time
0 3.008033 2.559917 0.449707 00:00
1 2.113339 1.726179 0.471273 00:00
2 1.688941 1.823044 0.389648 00:00
3 1.466082 1.699160 0.462646 00:00
4 1.319908 1.701673 0.516764 00:00
5 1.177464 1.837683 0.543050 00:00
6 1.041084 2.043768 0.554688 00:00
7 0.923601 2.067982 0.549886 00:00
8 0.819859 2.061354 0.562988 00:00
9 0.735049 2.076721 0.568685 00:00
10 0.664878 2.080706 0.570231 00:00
11 0.614425 2.117641 0.586263 00:00
12 0.577034 2.142265 0.588053 00:00
13 0.554870 2.124338 0.591227 00:00
14 0.543019 2.121613 0.590658 00:00

So this model actually did worse than our previous one - why? Because the model is now deeper (even if only by one extra layer), we are probably suffering from exploding or vanishing activations.

Generally a deeper model gives us more capacity and the potential for better results; however, it is also harder to train, because the compounded effect of repeated matrix multiplications can make the activations (and therefore the gradients) explode or vanish.
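
As a toy illustration of the problem (not part of the model), repeatedly scaling a value by a factor even slightly above or below 1 makes it explode or vanish, and the same thing happens to activations and gradients passed through many repeated matrix multiplications:

x = 1.0
for i in range(100): x = x * 1.1
print(x)   # ~13,780 - the value has exploded
x = 1.0
for i in range(100): x = x * 0.9
print(x)   # ~0.00003 - the value has vanished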

Researchers have developed two architectures to address this: long short-term memory (LSTM) layers and gated recurrent units (GRUs).

8 Language Model 6 - LSTMs

LSTMs were introduced by Sepp Hochreiter and Jürgen Schmidhuber in 1997, and they maintain 2 hidden states.

In our previous RNN we have one hidden state ‘h’ that does 2 things:

  • Holds signal to help predict the next word
  • Holds signal of all previous words

These are potentially very different things to keep together in one set of activations, and in practice RNNs are not very good at retaining the second, longer-term kind of information. LSTMs add a second hidden state, called the cell state, specifically to handle this requirement, acting as a kind of long short-term memory.

Let's look at the architecture of an LSTM.

So the inputs, which come in from the left, are:

  • Xt: input
  • ht-1: previous hidden state
  • ct-1: previous cell state

The 4 orange boxes are layers with either sigmoid or tanh activation functions. The green circles are element-wise operations. The outputs on the right are:

  • ht: new hidden state
  • ct: new cell state

These will be used as inputs at the next time step. The 4 orange layers are called gates. Note also how little the cell state running along the top is changed as it passes through; this is what allows it to persist information over longer time spans.

8.1 The 4 Gates of an LSTM

  1. Forget gate
  2. Input gate
  3. Cell gate
  4. Output gate

The first gate, the forget gate, is a linear layer followed by a sigmoid; it gives the LSTM the ability to forget things about its long-term state held in the cell state. For example, when the input is an xxbos (beginning-of-stream) token, we might expect the LSTM to learn to trigger this gate and reset its cell state.

The second and third gates work together to update/add to the cell state. The input gate decides which parts of the cell state to update, and the cell gate decides what those updated values should be.

The output gate decides what information from the cell state is used to generate the output.

We can define this as the following class.

class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)
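
As a quick sanity check (an illustrative snippet, assuming the imports used throughout), we can push a batch of random inputs through a single step of this cell and confirm the shapes:

cell = LSTMCell(ni=64, nh=64)
x = torch.randn(16, 64)                    # a batch of 16 input vectors
h = c = torch.zeros(16, 64)                # initial hidden and cell states
out, (new_h, new_c) = cell(x, (h, c))
out.shape, new_h.shape, new_c.shape        # all torch.Size([16, 64])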

We can refactor the code to make this more efficient, in particular creating just one big matrix multiplication rather than 4 smaller ones.

class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.ih = nn.Linear(ni,4*nh)
        self.hh = nn.Linear(nh,4*nh)

    def forward(self, input, state):
        h,c = state
        # One big multiplication for all the gates is better than 4 smaller ones
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()

        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h,c)

The PyTorch chunk method splits a tensor into a given number of pieces; for example, into 2:

t = torch.arange(0,10); t
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
t.chunk(2)
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))

Here we will use PyTorch's nn.LSTM to define a 2-layer LSTM model with the same structure as model 5, but with LSTM cells in place of plain RNN cells. Because this network should be more stable and easier to train, we can use a higher learning rate and expect better results.

class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)
    
    def reset(self): 
        for h in self.h: h.zero_()
learn = Learner(dls, LMModel6(len(vocab), 64, 2), 
                loss_func=CrossEntropyLossFlat(), 
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)
epoch train_loss valid_loss accuracy time
0 3.007779 2.770814 0.284017 00:01
1 2.204949 1.782870 0.425944 00:01
2 1.606196 1.831585 0.462402 00:01
3 1.296969 1.999463 0.479411 00:01
4 1.080299 1.889699 0.553141 00:01
5 0.828938 1.813550 0.593262 00:01
6 0.623377 1.710710 0.662516 00:01
7 0.479048 1.723749 0.687663 00:01
8 0.350940 1.458227 0.718913 00:01
9 0.260764 1.484386 0.732096 00:01
10 0.201649 1.384711 0.752523 00:01
11 0.158970 1.384149 0.753011 00:01
12 0.132954 1.377875 0.750244 00:01
13 0.117867 1.367185 0.756104 00:01
14 0.109761 1.366078 0.756104 00:01

9 Language Model 7 - Weight-Tied Regularized LSTMs

While this new LSTM model did much better, we can see it is overfitting to the training data: notice how the training loss keeps going down while the validation loss does not really improve, so the model is not generalising well. Dropout is one regularization method we can use here to try to prevent overfitting. An LSTM architecture regularized with dropout and the other techniques described below is known as an AWD-LSTM.
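
As a reminder of how dropout itself works, here is a minimal from-scratch sketch (DropoutSketch is an illustrative, hypothetical class; the model below simply uses PyTorch's nn.Dropout): during training it zeroes each activation with probability p and rescales the survivors by 1/(1-p) so the expected activation value is unchanged.

class DropoutSketch(Module):
    def __init__(self, p=0.5): self.p = p
    def forward(self, x):
        if not self.training or self.p == 0.: return x
        # Keep each activation with probability 1-p, zero the rest, then
        # rescale so the expected value of the output matches the input
        mask = torch.bernoulli(torch.full_like(x, 1 - self.p))
        return x * mask / (1 - self.p)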

Activation regularization (AR) and temporal activation regularization (TAR) are two further regularization methods, similar in spirit to weight decay but applied to the activations rather than the weights.

To apply AR, we store the final activations of the LSTM and add the mean of their squares to the loss, scaled by a factor alpha that controls the strength of the penalty:

loss += alpha * activations.pow(2).mean()

TAR is connected to the sequential nature of text, i.e. the idea that the outputs of an LSTM should make sense when read in order. TAR encourages this by penalising large differences between consecutive activations, so that the outputs change as smoothly as possible from one time step to the next:

loss += beta * (activations[:,1:] - activations[:,:-1]).pow(2).mean()

AR is usually applied to the dropped-out activations (so we don't penalise the activations that dropout has already set to zero), while TAR is applied to the non-dropped-out activations (since the zeros introduced by dropout would create artificial differences between consecutive time steps). The fastai RNNRegularizer callback applies both of these penalties for us.

With weight tying we make use of a symmetry in this model: at the start, the embedding layer maps from words to activations, and at the end, the final linear layer maps from activations back to words. We might expect these two mappings to be very similar, if not the same, so we can encourage this explicitly by making the first and last layers share the same weight matrix.

self.h_o.weight = self.i_h.weight

So we can combine dropout with AR & TAR and weight tying to train our LSTM.

class LMModel7(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h_o.weight = self.i_h.weight
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
        
    def forward(self, x):
        raw,h = self.rnn(self.i_h(x), self.h)
        out = self.drop(raw)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(out),raw,out
    
    def reset(self): 
        for h in self.h: h.zero_()

# Create regularized learner using RNNRegularizer
learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),
                loss_func=CrossEntropyLossFlat(), metrics=accuracy,
                cbs=[ModelResetter, RNNRegularizer(alpha=2, beta=1)])

# Equivalently, we can use TextLearner, which adds these callbacks automatically
learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)

# Train the model and add extra regularization with weight decay
learn.fit_one_cycle(15, 1e-2, wd=0.1)
epoch train_loss valid_loss accuracy time
0 2.513700 1.898873 0.498942 00:01
1 1.559825 1.421029 0.651937 00:01
2 0.810041 1.324630 0.703695 00:01
3 0.406249 0.870849 0.801514 00:01
4 0.211201 1.012451 0.776774 00:01
5 0.117430 0.748297 0.827474 00:01
6 0.072397 0.652809 0.843587 00:01
7 0.050372 0.740491 0.826172 00:01
8 0.037560 0.796995 0.831462 00:01
9 0.028582 0.669326 0.850830 00:01
10 0.022323 0.614551 0.855632 00:01
11 0.018281 0.670560 0.858317 00:01
12 0.014915 0.645430 0.856771 00:01
13 0.012732 0.656426 0.855387 00:01
14 0.011765 0.683027 0.853271 00:01

10 Conclusion

In this article we have examined how to build an LSTM language model from scratch, working up to the AWD-LSTM architecture, which makes use of several regularization techniques including dropout, activation regularization (AR and TAR) and weight tying.
