# Assignment 2: Transformer Summarizer

Welcome to the second assignment of Course 4. In this assignment, you will explore summarization using the transformer model. Yes, you will implement the transformer decoder from scratch, but we will slowly walk you through it. There are many hints in this notebook, so feel free to use them as needed.

## Outline

### Introduction

Summarization is an important task in natural language processing and could be useful for a consumer enterprise. For example, bots can be used to scrape articles and summarize them, and you can then use sentiment analysis to identify the sentiment about certain stocks. Anyway, who wants to read an article or a long email today when you can build a transformer to summarize text for you? Let’s get started! By completing this assignment you will learn to:

• Use built-in functions to preprocess your data
• Implement DotProductAttention
• Implement Causal Attention
• Understand how attention works
• Build the transformer model
• Summarize an article

As you can tell, this model is different from the ones you have already implemented. It relies heavily on attention and does not process a sequence step by step the way recurrent models do, which allows for parallel computation.

# Part 1: Importing the dataset

Trax makes it easy to work with TensorFlow Datasets.

## 1.1 Tokenize & Detokenize helper functions

Just like in the previous assignment, the cell above loads in the encoder for you. Given any dataset, you have to be able to map words to their indices and indices to their words. The inputs and outputs of your Trax models are usually tensors of numbers where each number corresponds to a word. If you were to process your data manually, you would have to make use of the following:

• word2Ind: a dictionary mapping the word to its index.
• ind2Word: a dictionary mapping the index to its word.
• word2Count: a dictionary mapping the word to the number of times it appears.
• num_words: total number of words that have appeared.
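The four structures above can be sketched in plain Python. This is a minimal illustration on whitespace-split text, not the notebook's helper code. Reserving IDs 0 and 1 for `<pad>` and `<EOS>` mirrors the data shown later in this notebook, and counting num_words as the number of distinct vocabulary entries is an assumption.

```python
from collections import Counter


def build_vocab(sentences):
    """Build word2Ind, ind2Word, word2Count and num_words from whitespace-split text."""
    # IDs 0 and 1 are reserved for <pad> and <EOS>, as in the data used in this notebook.
    word2Ind = {"<pad>": 0, "<EOS>": 1}
    ind2Word = {0: "<pad>", 1: "<EOS>"}
    word2Count = Counter()
    for sentence in sentences:
        for word in sentence.split():
            if word not in word2Ind:
                index = len(word2Ind)
                word2Ind[word] = index
                ind2Word[index] = word
            word2Count[word] += 1
    # Assumption: num_words counts distinct vocabulary entries (special tokens included).
    num_words = len(word2Ind)
    return word2Ind, ind2Word, dict(word2Count), num_words


word2Ind, ind2Word, word2Count, num_words = build_vocab(["the cat sat", "the dog ran"])
print(word2Ind["cat"], ind2Word[2], word2Count["the"], num_words)  # 3 the 2 7
```

ind2Word is simply the inverse of word2Ind, which is what lets detokenization undo tokenization.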

Since you have already implemented these in previous assignments of the specialization, we will provide you with helper functions that will do this for you. Run the cell below to get the following functions:

• tokenize: converts a text sentence to its corresponding token list (i.e., a list of indices), splitting words into subwords along the way.
• detokenize: converts a token list to its corresponding sentence (i.e. string).
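To see the contract these two helpers satisfy (detokenize undoes tokenize), here is a word-level toy version with a small hypothetical vocabulary. The real helpers work on subwords through the vocabulary loaded above, so their IDs and spacing differ. Appending `<EOS>` (ID 1) to the token list is an assumption made to match the `<EOS>` = 1 convention in this notebook's data.

```python
# Hypothetical word-level vocabulary; the notebook's helpers use subwords instead.
VOCAB = ["<pad>", "<EOS>", "it", "was", "a", "sunny", "day"]
WORD2IND = {word: index for index, word in enumerate(VOCAB)}
EOS = 1


def tokenize(sentence):
    """Map each whitespace-separated word to its index and append <EOS>."""
    return [WORD2IND[word] for word in sentence.lower().split()] + [EOS]


def detokenize(token_ids):
    """Map indices back to words and join them with spaces."""
    return " ".join(VOCAB[i] for i in token_ids)


tokens = tokenize("It was a sunny day")
print(tokens)              # [2, 3, 4, 5, 6, 1]
print(detokenize(tokens))  # it was a sunny day <EOS>
```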

## 1.2 Preprocessing for Language Models: Concatenate It!

This week you will use a language model (a Transformer decoder) to solve an input-output problem. As you know, language models only predict the next word; they have no notion of inputs. To create a single input suitable for a language model, we concatenate the inputs with the targets, putting a separator in between. We also need to create a mask, with 0s at the inputs and 1s at the targets, so that the model is not penalized for mispredicting the article and only focuses on the summary. See the preprocess function below for how this is done.
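The concatenation and mask can be sketched as a generator over (article, summary) token lists. This is written from the description above and is not necessarily identical to the notebook's preprocess function. The IDs 1 for `<EOS>` and 0 for the `<pad>` separator match the data shown below. Yielding the same array as inputs and targets assumes the model shifts its input right internally, which the transformer language model later in this notebook does with tl.ShiftRight.

```python
import numpy as np

EOS, SEP = 1, 0  # <EOS> and the <pad> separator, as they appear in the data below


def preprocess(stream):
    """Turn (article, summary) token lists into (inputs, targets, mask) for a language model."""
    for article, summary in stream:
        # article -> <EOS> -> <pad> -> summary -> <EOS>
        joint = np.array(list(article) + [EOS, SEP] + list(summary) + [EOS])
        # 0s for the article, its <EOS> and the separator; 1s for the summary and its <EOS>.
        mask = np.array([0] * (len(article) + 2) + [1] * (len(summary) + 1))
        # Inputs and targets are the same array: the model shifts its input right internally.
        yield joint, joint, mask


inputs, targets, mask = next(preprocess([([27, 1091, 23], [4927, 7577])]))
print(inputs.tolist())  # [27, 1091, 23, 1, 0, 4927, 7577, 1]
print(mask.tolist())    # [0, 0, 0, 0, 0, 1, 1, 1]
```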

Single example mask:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

Single example:

By . Margot Peppers . Nigerian and Cameroonian pop star Dencia has hit
out at Lupita Nyong'o for her new contract with Lancome, accusing her
of bowing to 'white people companies'. In an angry tweet directed at
the 12 Years A Slave star, she wrote: 'Oh @Lupita_Nyongo cln't talk
abt the bleaching creams white people (Companies) make cuz the white
man pays her, they own her!! [sic]'. The comment comes just a month
after Miss Nyong'o mentioned Dencia - who has been accused of
marketing her own brand of skin-bleaching cream called Whitenicious -
in a speech about learning to value the color of her own skin. Scroll
down for video . Butting heads: Nigerian and Cameroonian pop star
Dencia has hit out at Lupita Nyong'o for her new contract with
Lancome, accusing her of bowing to 'white people companies' Fighting
words: In a tweet directed at the 12 Years A Slave star, she wrote:
'Oh @Lupita_Nyongo cln't talk abt the bleaching creams white people
(Companies) make cuz the white man pays her, they own her!! [sic]' The
pop star is no stranger to . controversy; in a February interview with
Ebony, she all but admitted . that Whitenicious is intended as a skin-
lightener, not as a cure for . dark spots as it claims. 'When . you
take that picture and you put a picture of Dencia darker, this is .
what you're telling people - the product really works,' she said. 'And
guess what? People really want to buy it. It's what it is. I don't
really care.' Given her defiant and hypocritical attitude, it's no
surprise the fiery singer was angered when Miss Nyong'o called her out
in a speech at Essence's Black Women in Hollywood event on February
27. Influential: In a recent speech, Miss Nyong'o read out loud a
letter from a fan who said she decided not to buy Dencia's skin-
whitening cream Whitenicious because the actress had inspired her to
love her own skin . On-screen: Miss Nyong'o won an Oscar for Best
Supporting Actress for her role in 2013 film 12 Years A Slave . In her
talk, the 30-year-old opened up about how conventional standards of
beauty once affected her self-esteem, reading aloud a letter written
to her by a young girl who viewed her as a role model. 'Dear Lupita,'
reads the letter. 'I think you're really lucky to be this black but
yet this successful in Hollywood overnight. I was just about to buy
Dencia's Whitenicious cream to lighten my skin when you appeared on
the world map and saved me.' 'My heart bled a little when I read those
words,' the actress said through tears, explaining how as a child,
she, too, would pray that she'd one day wake up with lighter skin.
Hypocritical: Dencia is no stranger to controversy; in a February
interview with Ebony, she essentially admitted that Whitenicious is
intended as a skin-lightener, not as a cure for dark spots as it
claims . Perpetuating the problem: 'When you take that picture and you
put a picture of Dencia darker, this is what you're telling people -
the product really works,' she said. 'And guess what? People really
want to buy it' But while the actress saw the letter as a source of
inspiration, Dencia took it as a personal attack. After her angry
tweet at Miss Nyong'o, criticism poured in, with one person tweeting:
'B**** lupita is the new face of Lancôme!! SHE WINS!! And you're just
TRASH [sic]'. In her response, Dencia said of the cosmetics company:
'But they sell bleaching cream tho [sic]'. The pop star is likely
referring to Lancome's Blanc Expert range of cosmetics, which are
actually advertised as 'brighteners' that 'regulate melanin production
and awaken the luminosity of the skin'. And as far as Dencia's claim
that Lancome is a 'white people company', a quick perusal of the
website reveals that it has a number of concealers and foundations in
darker skin tones.<EOS><pad>Dencia's comment is hypocritical
considering she recently courted controversy for marketing 'dark spot
remover' Whitenicious, which is frequently used as a skin-whitening
cream .<EOS>


## 1.3 Batching with bucketing

As in the previous week, we use bucketing to create batches of data.
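Bucketing groups examples of similar length so that each batch needs little padding; the batch printed below, for instance, holds 2 examples padded to length 1024. The following is a simplified, hypothetical sketch of the idea, not trax's implementation. The boundaries and batch sizes are illustrative values, and padding every batch to its bucket boundary is this sketch's choice.

```python
import numpy as np


def bucket_batches(examples, boundaries, batch_sizes, pad_id=0):
    """Group token lists into batches of similar length, padding each batch to its bucket boundary.

    Examples longer than the last boundary, and partially filled buckets left at
    the end of the stream, are simply dropped in this sketch.
    """
    buckets = {bound: [] for bound in boundaries}
    for tokens in examples:
        # Smallest boundary that can hold this example.
        bound = next((b for b in boundaries if len(tokens) <= b), None)
        if bound is None:
            continue
        buckets[bound].append(tokens)
        batch_size = batch_sizes[boundaries.index(bound)]
        if len(buckets[bound]) == batch_size:
            batch = np.full((batch_size, bound), pad_id, dtype=np.int64)
            for row, seq in enumerate(buckets[bound]):
                batch[row, : len(seq)] = seq
            yield batch
            buckets[bound] = []


examples = [[7, 7, 7], [5] * 5, [9, 9], [3] * 6]
for batch in bucket_batches(examples, boundaries=[4, 8], batch_sizes=[2, 2]):
    print(batch.shape)  # (2, 4) then (2, 8)
```

Short examples end up together in small batches, so less compute is spent on padding than if everything were padded to the longest example.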

(2, 1024)

[   27  1091    23    46  3873  1248 16013   256 11599 23297   102    68
24308     7     5  1037  1958   320  1477   105  2557   186  4133    28
18175  1348  1287     3  4927  7577    28  8478 10120 19134  7951   364
7317  4990    79     2   393     2   186  8962  2995  9813  4476  3632
2270     5     2   705     2   721 10731    16   186 17136    16   193
54   102    41  1459   320    31 16946    47     2   119  3770   278
355    28   622   263    78  2613     3   312  4543     4  8662  3788
3632  2270     5     6  3048 23524     2  1210     2  1958   320  2033
105    61     2 19134  7951   364  7317  4990    79 24810    17   213
1091     2   931   320   213 16946    47   415 20579 20964    58  1782
863   213  7726     2   213  7599  3938  4133    28 26719     4   752
1480  2868   132    68   583  3898 20579 20964    58   240   197     3
4531  9531  2959   127   132    28 27439  9275  1628  1602     3  8406
5364    11  4927  7577    28  8478 10120 19134  7951   364  7317  4990
79     2   393     2   497     2   186    68 24308  8962  2995  9813
4476  3632  2270     5     2   705     2   721 10731    16   186 17136
16   193    54   809    31   278    78  2613  7511    15  1037 20274
21   379 21549  7150    11  9813  4476  3632  2270     5    80 18649
1496   667   213 17136    45    78    15   882  1838   213  2439  7883
379    27  1147     6   104     6   292   966    43 11850   213  1621
2   931   320   213 13021     4     2    35    22   206    19  5632
213  1018   111   213  2948   186   213 25931     4     3  2713  7801
320    28  6105    32   922  1838   213  6350   141   102 24114    75
78  2613   186  7511    41  2362     2    41   233  3632  2270     5
6  3048 23524  1955    78    28 11261  1797  1782   198    25    92
3787  3103   527 13747   320   213  7599     2   487   159   213   669
27884     4  1622 27872   391  5977  3103   527  2918   186  1472   320
18    46   810   132    28  2439  7726  3898   213 13021     4   127
3    34    31 18649  3347     2   148 19134  7951   364  7317  4990
79   186  9813  4476  3632  2270     5    18 17136    45    78    31
5369   186 19175     5     3     9  2789    25 11203     2   412    25
213   966   186    54  1697     3 12849    14    11  7317  4990    79
12365   146 24810    17   213  1091  4617 27439  9275  1628  4543     4
8662  3788  3632  2270     5     6  3048 23524     2   186   131  4133
28 26719     4  6901   809   213    60     6  1797  6350   809   213
414     8 12370    21    12   186   710   171   864  2362   809   213
1610   379     9 16946    47   415  3357 15581    81     7     5  1431
1890   163  4336  7188    20    78  3632  2270     5     6  3048 23524
7     5   661     2    35   646    25 17926 25290 16741    20  4140
2   213 13021     4   127     3 19134  7951   364  7317  4990    79
1353  3873  1248 11599 23297     2 17260  8041   893   213  5627   527
28   966     2  2439  1740  1524  7726   186 23638    16 24668 21273
204     2   931   320  1882     3  4531  9531  2959   127 19134  7951
364  7317  4990    79    43  9363     4   760    70    35    62    19
2851  2754   103  1353    70  1480 22646   272  7304   132    28  1501
809   213  1881  1610     3   305  1353   475   809   213 16946    47
415  7411    84    78   281  3997    88   226 20934     4     3  9813
4476  3632  2270     5  1353    43  3873  1248   966 17260  8041 16704
464   186  2439  1740  1524  7726   186  1233   320   213  1156 10835
78   281  1696    88   226 20934     4     3    27   924  3729    23
46  4648  1019  3112  1859   809 18235  5333  9141 25733   812    10
1     0  4927  7577    28  8478 10120 19134  7951   364  7317  4990
79     2   393     2   186  8962  2995  9813  4476  3632  2270     5
2   705     2   721  2557   102  4925  1838    28   622    78  2613
1859 16346 27439  6774  1628   312    15  1037     2  4543     4  8662
3788  3632  2270     5     6  3048 23524     2  1210     2  1958   320
1477   105     2 19134  7951   364  7317  4990    79 12365 24810    17
68 16346 27439  6774  1628   305  4133 26719     4  6901   186   864
2362   320   423    68  1955 16346 27439  6774  1628    27  1147     6
104     6   292  2635 11850   213  1621  2104     1     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0     0     0     0     0     0     0     0     0
0     0     0     0]


Things to notice:

• First come the token IDs of the article's words.
• Then a 1, which represents the <EOS> tag of the article.
• Then a 0, which represents a <pad> tag used as a separator.
• After the first 0 (<pad> tag) come the token IDs of the words of the article's summary.
• The second 1 represents the <EOS> tag of the summary.
• All the trailing 0s are <pad> tags appended to maintain a consistent length (if you don't see them, the example is already at the maximum length).

Article:

A woman has been charged with reckless manslaughter after her
boyfriend's mother tried to stop them fighting and suffered a fatal
heart attack. Claudia Yanira Hernandez Soriano, 25, and Juan Francisco
Martinez Rojas, 28, started punching and scratching each other after
they returned to their Bergen, New Jersey home following a party early
on Monday. When Ana Angelina Rojas-Jovel, 45, tried to break them up,
Hernandez Soriano assaulted the woman, according to the Bergen County
Prosecutor. 'During the assault, the victim apparently suffered a
cardiac event which resulted in her death,' Prosecutor John L.
Molinelli said in a statement. Fight: Claudia Yanira Hernandez
Soriano, 25, above, and her boyfriend Juan Francisco Martinez Rojas,
28, started punching and scratching each other at their home on Monday
when his mother intervened . Injured: Martinez Rojas' booking shot
shows the scratches on his face from the domestic dispute . A seven-
year-old child also witnessed the fight, according to the prosecutor,
but he did not reveal the relationship between the adults and the
youngster. Police responded to a 911 call from the apartment just
after 4am on Monday and when they arrived, they found Rojas-Jovel dead
on a bedroom floor. 'There were no obvious signs of trauma to the
victim, however... the [couple] displayed signs of injury and appeared
to have been involved in a domestic assault,' the prosecutor said. In
their booking photos, both Hernandez Soriano and Martinez Rojas have
scratches on their faces and necks. The pair were interviewed, as were
the child and other residents. Scene: Soriano allegedly then assaulted
the woman, Ana Angelina Rojas-Jovel, and she suffered a cardiac arrest
at the first-floor apartment at the house (pictured) and died before
police arrived at the scene . The Bergen County Medical Examiner's
Office conducted an autopsy on Rojas-Jovel's body, but results were
pending toxicology tests, the prosecutor said. Hernandez Soriano was
charged with manslaughter, endangering the welfare of a child,
domestic violence simple assault and hindering apprehension, according
to authorities. Molinelli said Hernandez Soriano also hid evidence -
but would not detail what it was - which investigators later recovered
in a search at the crime scene. She was held at the Bergen County Jail
on \$250,000 bail. Martinez Rojas was also charged with child endangerment and domestic violence simple assault and sent to the county jail on \$75,000 bail. A court hearing has been scheduled for
Thursday morning at Hackensack Superior Court.<EOS><pad>Claudia Yanira
Hernandez Soriano, 25, and Juan Francisco Martinez Rojas, 28, started
fighting after returning from a party on Monday morning . When his
mother, Ana Angelina Rojas-Jovel, 45, tried to stop them, Hernandez
Soriano allegedly assaulted her . She suffered cardiac arrest and
police arrived to find her dead . A seven-year-old girl witnessed the


You can see that the data has the following structure:

• [Article] -> <EOS> -> <pad> -> [Article Summary] -> <EOS> -> (possibly) multiple <pad>

The loss is computed only on the summary (the positions where the mask is 1), using cross-entropy as the loss function.
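As a sketch of what taking the loss only on the summary means, the mask from preprocessing can weight a per-token cross-entropy so that article and separator positions contribute nothing. This is plain NumPy written for illustration; it is not trax's loss layer, and the function name is hypothetical.

```python
import numpy as np


def masked_cross_entropy(log_probs, targets, mask):
    """Mean negative log-likelihood over the positions where mask == 1.

    log_probs: (seq_len, vocab_size) log-probabilities from the model
    targets:   (seq_len,) integer token IDs
    mask:      (seq_len,) 0 for article/separator positions, 1 for summary positions
    """
    # Log-probability assigned to the correct token at every position.
    picked = log_probs[np.arange(len(targets)), targets]
    weights = mask.astype(np.float64)
    return -(picked * weights).sum() / weights.sum()


log_probs = np.log(np.array([[0.1, 0.9], [0.9, 0.1], [0.2, 0.8]]))
targets = np.array([0, 0, 1])
mask = np.array([0, 1, 1])
# The badly predicted first position (p = 0.1) is masked out and does not count.
print(round(masked_cross_entropy(log_probs, targets, mask), 4))  # 0.1643
```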

# Part 2: Summarization with transformer

Now that we have given you the data generator and have handled the preprocessing for you, it is time for you to build your own model. We saved you some time because we know you have already preprocessed data before in this specialization, so we would rather you spend your time doing the next steps.

You will implement attention from scratch and then use it in your transformer model. Concretely, you will see how attention works and how it is used inside the transformer decoder.

## 2.1 Dot product attention

Now you will implement dot product attention, which takes in a query, key, value, and mask and returns the attention output.

Here are some helper functions that will help you create tensors and display useful information:

• create_tensor creates a jax numpy array from a list of lists.
• display_tensor prints out the shape and the actual tensor.
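The two helpers are small. A NumPy stand-in, assuming they do no more than the descriptions above say, could look like this (the notebook's create_tensor returns a jax numpy array rather than a NumPy one):

```python
import numpy as np


def create_tensor(t):
    """Create an array from a list of lists (NumPy stand-in for the jax numpy version)."""
    return np.array(t)


def display_tensor(t, name):
    """Print the tensor's shape followed by its values."""
    print(f"{name} shape: {t.shape}\n")
    print(f"{t}\n")


q = create_tensor([[1, 0, 0], [0, 1, 0]])
display_tensor(q, "query")
```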

Before implementing it yourself, you can play around with a toy example of dot product attention without the softmax operation. Technically, it is not dot product attention without the softmax, but this avoids giving away too much of the answer; the idea is to display these tensors to give you a sense of what they look like.

The formula for attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right)V$$

$d_{k}$ stands for the dimension of queries and keys.

The query, key, value, and mask tensors are provided for this example.

Notice that the masking is done using very negative values that will yield a similar effect to using $-\infty$.
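The toy computation whose outputs follow can be reproduced in a few lines of NumPy, used here in place of jax numpy (the matmul, swapaxes, and sqrt calls have the same names in both). As in the toy example, the softmax is left out on purpose, and the mask is added as a matrix of 0s and very negative values.

```python
import numpy as np

query = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
key = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
value = np.array([[0, 1, 0], [1, 0, 1]], dtype=np.float32)
# Additive mask: 0 keeps a score, a very negative number stands in for -infinity.
mask = np.array([[0, 0], [-1e9, 0]], dtype=np.float32)

d_k = query.shape[-1]
# Scaled scores Q K^T / sqrt(d_k); shape (2, 2).
q_dot_k = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(np.float32(d_k))
masked_q_dot_k = q_dot_k + mask
# Softmax deliberately skipped, exactly as in the toy example.
masked_q_dot_k_dot_v = np.matmul(masked_q_dot_k, value)
print(q_dot_k)
print(masked_q_dot_k_dot_v)
```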

query shape: (2, 3)

[[1 0 0]
[0 1 0]]

key shape: (2, 3)

[[1 2 3]
[4 5 6]]

value shape: (2, 3)

[[0 1 0]
[1 0 1]]

mask shape: (2, 2)

[[ 0.e+00  0.e+00]
[-1.e+09  0.e+00]]



query dot key shape: (2, 2)

[[0.57735026 2.309401  ]
[1.1547005  2.8867514 ]]



masked query dot key shape: (2, 2)

[[ 5.7735026e-01  2.3094010e+00]
[-1.0000000e+09  2.8867514e+00]]



masked query dot key dot value shape: (2, 3)

[[ 2.3094010e+00  5.7735026e-01  2.3094010e+00]
[ 2.8867514e+00 -1.0000000e+09  2.8867514e+00]]



In order to use the previous dummy tensors to test some of the graded functions, a batch dimension should be added to them so they mimic the shape of real-life examples. The mask is also replaced by a version of it that resembles the one that is used by trax:

query with batch dim shape: (1, 2, 3)

[[[1 0 0]
[0 1 0]]]

key with batch dim shape: (1, 2, 3)

[[[1 2 3]
[4 5 6]]]

value with batch dim shape: (1, 2, 3)

[[[0 1 0]
[1 0 1]]]

mask shape: (2, 2)

[[ True  True]
[False  True]]



### Exercise 01

Instructions: Implement the dot product attention. Concretely, implement the following equation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}} + M\right)V$$

• $Q$: query
• $K$: key
• $V$: values
• $M$: mask
• $d_k$: depth/dimension of the queries and keys (used for scaling down)

You can implement this formula with either trax numpy (trax.math.numpy) or regular numpy, but using jnp is recommended.

Something to take into consideration is that within trax, masks are tensors of True/False values, not 0s and very negative values (standing in for $-\infty$) as in the previous example. Within the graded function, don’t apply the mask by summing matrices; instead, use jnp.where() and treat the mask as a boolean tensor with False for the values that need to be masked and True for the ones that don’t.

Also take into account that the real tensors have more dimensions than the toy ones you just played with (for example, a leading batch axis). Because of this, avoid shorthand operations such as @ for the dot product or .T for transposing; in particular, .T reverses all axes, including the batch axis. Use jnp.matmul() and jnp.swapaxes() instead.

This is the self-attention block for the transformer decoder. Good luck!
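A sketch of the graded function's logic, written in NumPy (jax numpy offers the same matmul, swapaxes, where, exp, and sqrt calls) and following the hints above: masking with where() and a boolean mask, and swapaxes instead of .T. Computing the softmax by subtracting the row maximum is this sketch's choice; other numerically stable formulations work too. With the batched toy tensors it should reproduce the output shown below.

```python
import numpy as np


def dot_product_attention(query, key, value, mask):
    """Scaled dot-product attention; mask is boolean with True = keep and False = mask out."""
    depth = query.shape[-1]
    # Swap only the last two axes so any leading batch (or head) axes stay in place.
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # Select with where() instead of adding a mask matrix.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax over the key axis (subtracting the row max for numerical stability).
    dots = dots - dots.max(axis=-1, keepdims=True)
    weights = np.exp(dots)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, value)


q = np.array([[[1, 0, 0], [0, 1, 0]]])
k = np.array([[[1, 2, 3], [4, 5, 6]]])
v = np.array([[[0, 1, 0], [1, 0, 1]]])
m = np.array([[True, True], [False, True]])
print(dot_product_attention(q, k, v, m))
```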

DeviceArray([[[0.8496746 , 0.15032545, 0.8496746 ],
[1.        , 0.        , 1.        ]]], dtype=float32)



query matrix (2D tensor) shape: (2, 3)

[[1 0 0]
[0 1 0]]

batch of two (multi-head) collections of query matrices (4D tensor) shape: (2, 2, 2, 3)

[[[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]]

[[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]]]

one batch of concatenated heads of query matrices (3d tensor) shape: (1, 2, 6)

[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]

three batches of concatenated heads of query matrices (3d tensor) shape: (3, 2, 6)

[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]


It is important to know that the following 3 functions would normally be defined within the CausalAttention function further below.

However, this makes these functions harder to test. Because of this, they are shown individually, using a closure (when necessary) that simulates them being inside the CausalAttention function; the closure is needed because they rely on some variables that can be accessed from within CausalAttention.

### Support Functions

compute_attention_heads : Gets an input $x$ of dimension (batch_size, seqlen, n_heads $\times$ d_head), splits the last (depth) dimension into n_heads pieces of size d_head, and stacks them along the zeroth dimension, producing a tensor of dimension (batch_size $\times$ n_heads, seqlen, d_head) so that all heads go through the same batched matrix multiplication.

For the closures you only have to fill the inner function.
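One way to write the inner function is a reshape, a transpose, and another reshape. This is a NumPy sketch; the closure name compute_attention_heads_closure is an assumption modeled on the compute_attention_output_closure mentioned further below, and the input is the example tensor shown next.

```python
import numpy as np


def compute_attention_heads_closure(n_heads, d_head):
    """Return a function that splits heads; n_heads and d_head are captured by the closure."""

    def compute_attention_heads(x):
        # x: (batch_size, seqlen, n_heads * d_head)
        batch_size, seqlen = x.shape[0], x.shape[1]
        # Split the depth axis into (n_heads, d_head).
        x = np.reshape(x, (batch_size, seqlen, n_heads, d_head))
        # Move heads next to the batch axis: (batch_size, n_heads, seqlen, d_head).
        x = np.transpose(x, (0, 2, 1, 3))
        # Merge batch and heads: (batch_size * n_heads, seqlen, d_head).
        return np.reshape(x, (batch_size * n_heads, seqlen, d_head))

    return compute_attention_heads


x = np.array([[[1, 0, 0, 1, 0, 0], [0, 1, 0, 0, 1, 0]]] * 3)  # shape (3, 2, 6)
heads = compute_attention_heads_closure(n_heads=2, d_head=3)(x)
print(heads.shape)  # (6, 2, 3)
```

The transpose is what keeps each head's slice of every token together; reshaping straight from (3, 2, 6) to (6, 2, 3) would mix tokens and heads.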

input tensor shape: (3, 2, 6)

[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]

output tensor shape: (6, 2, 3)

[[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]]



dot_product_self_attention : Creates a mask matrix with False values above the diagonal and True values on and below it, and calls DotProductAttention, which implements dot product self attention.
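A NumPy sketch of the causal mask and the call into attention. The compact dot_product_attention here restates the earlier exercise so the block runs on its own, and np.tril is this sketch's way of building the lower-triangular mask. With the batched toy tensors it should give the output shown below.

```python
import numpy as np


def dot_product_attention(query, key, value, mask):
    """Scaled dot-product attention with a boolean mask (True = keep)."""
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(query.shape[-1])
    dots = np.where(mask, dots, np.full_like(dots, -1e9))
    weights = np.exp(dots - dots.max(axis=-1, keepdims=True))
    return np.matmul(weights / weights.sum(axis=-1, keepdims=True), value)


def dot_product_self_attention(q, k, v):
    """Causal self-attention: position i may attend only to positions 0..i."""
    mask_size = q.shape[-2]
    # True on and below the diagonal; the leading axis of size 1 broadcasts over the batch.
    mask = np.tril(np.ones((1, mask_size, mask_size), dtype=bool))
    return dot_product_attention(q, k, v, mask)


q = np.array([[[1, 0, 0], [0, 1, 0]]])
k = np.array([[[1, 2, 3], [4, 5, 6]]])
v = np.array([[[0, 1, 0], [1, 0, 1]]])
print(dot_product_self_attention(q, k, v))
```

The first query can only see the first key, so it copies the first value row exactly.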

DeviceArray([[[0.        , 1.        , 0.        ],
[0.8496746 , 0.15032543, 0.8496746 ]]], dtype=float32)



compute_attention_output : Undoes compute_attention_heads by splitting the first (vertical) dimension and stacking the pieces along the last (depth) dimension, producing a tensor of dimension (batch_size, seqlen, n_heads $\times$ d_head). These operations concatenate (stack/merge) the heads.
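The inverse operation is the same three steps in reverse order. A NumPy sketch of the closure, applied to the example tensor shown next:

```python
import numpy as np


def compute_attention_output_closure(n_heads, d_head):
    """Return a function that merges heads back; n_heads and d_head come from the closure."""

    def compute_attention_output(x):
        # x: (batch_size * n_heads, seqlen, d_head)
        seqlen = x.shape[1]
        # Separate batch and heads: (batch_size, n_heads, seqlen, d_head).
        x = np.reshape(x, (-1, n_heads, seqlen, d_head))
        # Put heads next to the depth axis: (batch_size, seqlen, n_heads, d_head).
        x = np.transpose(x, (0, 2, 1, 3))
        # Concatenate the heads: (batch_size, seqlen, n_heads * d_head).
        return np.reshape(x, (-1, seqlen, n_heads * d_head))

    return compute_attention_output


x = np.array([[[1, 0, 0], [0, 1, 0]]] * 6)  # shape (6, 2, 3)
merged = compute_attention_output_closure(n_heads=2, d_head=3)(x)
print(merged.shape)  # (3, 2, 6)
```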

input tensor shape: (6, 2, 3)

[[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]

[[1 0 0]
[0 1 0]]]

output tensor shape: (3, 2, 6)

[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]

[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]



### Causal Attention Function

Now it is time for you to put everything together within the CausalAttention or Masked multi-head attention function:

Instructions: Implement the causal attention.
Your model returns the causal attention through a tl.Serial with the following:

• tl.Branch: consisting of 3 [tl.Dense(d_feature), ComputeAttentionHeads] to account for the queries, keys, and values.
• tl.Fn: takes in the dot_product_self_attention function and uses it to compute the dot product using $Q$, $K$, $V$.
• tl.Fn: takes in compute_attention_output_closure to merge (concatenate) the heads back together.
• tl.Dense: final Dense layer, with dimension d_feature.

Remember that in order for trax to properly handle the functions you just defined, they need to be added as layers using the tl.Fn() function.
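Because tl.Serial, tl.Branch, and tl.Fn need trax itself, here is instead a plain NumPy forward pass that follows the same sequence of steps, which makes the data flow and shapes easy to check. The parameter layout, init_params, and the random weights are hypothetical; this is not how trax stores layer weights.

```python
import numpy as np


def causal_attention_forward(x, params, n_heads):
    """Forward pass mirroring: 3 x [Dense, split heads] -> causal attention -> merge heads -> Dense."""
    batch_size, seqlen, d_feature = x.shape
    d_head = d_feature // n_heads

    def split_heads(t):  # (batch, seqlen, n_heads * d_head) -> (batch * n_heads, seqlen, d_head)
        t = t.reshape(batch_size, seqlen, n_heads, d_head).transpose(0, 2, 1, 3)
        return t.reshape(batch_size * n_heads, seqlen, d_head)

    def merge_heads(t):  # (batch * n_heads, seqlen, d_head) -> (batch, seqlen, n_heads * d_head)
        t = t.reshape(batch_size, n_heads, seqlen, d_head).transpose(0, 2, 1, 3)
        return t.reshape(batch_size, seqlen, n_heads * d_head)

    def dense(t, name):
        weight, bias = params[name]
        return np.matmul(t, weight) + bias

    # Branch: one Dense projection (then head split) each for queries, keys and values.
    q, k, v = (split_heads(dense(x, name)) for name in ("q", "k", "v"))
    # Causal dot-product attention over each head.
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(d_head)
    causal_mask = np.tril(np.ones((seqlen, seqlen), dtype=bool))
    scores = np.where(causal_mask, scores, -1e9)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Merge heads, then the final Dense layer with d_feature units.
    return dense(merge_heads(np.matmul(weights, v)), "out")


def init_params(d_feature, seed=0):
    """Hypothetical random weights for the four Dense layers."""
    rng = np.random.default_rng(seed)
    return {name: (rng.normal(scale=0.1, size=(d_feature, d_feature)), np.zeros(d_feature))
            for name in ("q", "k", "v", "out")}


x = np.random.default_rng(1).normal(size=(2, 5, 8))
y = causal_attention_forward(x, init_params(8), n_heads=2)
print(y.shape)  # (2, 5, 8)
```

A quick way to check causality is to change the later positions of x and confirm that the outputs at earlier positions do not move.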

Serial[
Branch_out3[
]
DotProductAttn_in3
AttnOutput
Dense_512
]



## 2.3 Transformer decoder block

Now that you have implemented the causal part of the transformer, you will implement the transformer decoder block. Concretely, you will be implementing the architecture shown in this image.

To implement this function, you will have to call the CausalAttention or Masked multi-head attention function you implemented above. You will also have to add a feedforward block, which consists of:

• tl.LayerNorm: used to layer normalize
• tl.Dense: the dense layer
• ff_activation: feed forward activation (we use ReLU here)
• tl.Dropout: dropout layer
• tl.Dense: dense layer
• tl.Dropout: dropout layer

Finally, once you implement the feedforward, you can go ahead and implement the entire block using:

• tl.Residual: takes in the tl.LayerNorm(), the causal attention block, and tl.Dropout.

• tl.Residual: takes in the feedforward block you will implement.
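Each residual wrapper computes output = input + sublayer(input), where the sublayer starts with a LayerNorm. A NumPy sketch of the feedforward half in evaluation mode follows. The weight names are hypothetical, the learned LayerNorm scale and shift are omitted, and d_model = 8, d_ff = 32 are chosen only for the example (the model summary for this exercise uses 512 and 2048).

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """Normalize over the feature axis (learned scale and shift omitted in this sketch)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def feed_forward_residual(x, w1, b1, w2, b2):
    """Residual feedforward block in eval mode: x + Dense(ReLU(Dense(LayerNorm(x)))).

    Dropout is omitted because it is the identity at inference time.
    """
    hidden = np.maximum(np.matmul(layer_norm(x), w1) + b1, 0.0)  # d_model -> d_ff, ReLU
    return x + np.matmul(hidden, w2) + b2                          # d_ff -> d_model, plus skip


rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
x = rng.normal(size=(2, 4, d_model))
w1, b1 = rng.normal(scale=0.1, size=(d_model, d_ff)), np.zeros(d_ff)
w2, b2 = rng.normal(scale=0.1, size=(d_ff, d_model)), np.zeros(d_model)
print(feed_forward_residual(x, w1, b1, w2, b2).shape)  # (2, 4, 8)
```

Because of the skip connection, a sublayer whose output is all zeros leaves the input unchanged, which is part of why deep stacks of these blocks train well.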

### Exercise 03

Instructions: Implement the transformer decoder block. Good luck!

[Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
], Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
]]



## 2.4 Transformer Language Model

You will now bring it all together. In this part you will use all the subcomponents you previously built to make the final model. Concretely, here is the image you will be implementing.

### Exercise 04

Instructions: Previously you coded the decoder block. Now you will code the transformer language model. Here is what you will need.

• positional_encoder: a list containing the following layers:
  • tl.Embedding
  • tl.Dropout
  • tl.PositionalEncoding

• decoder_blocks: a list of n_layers decoder blocks.

• tl.Serial: takes in the following layers or lists of layers:
  • tl.ShiftRight: shift the tensor to the right by padding on axis 1.
  • positional_encoder: encodes the text positions.
  • decoder_blocks: the ones you created.
  • tl.LayerNorm: a layer norm.
  • tl.Dense: takes in the vocab_size.
  • tl.LogSoftmax: to predict.
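The shift is what lets the targets equal the inputs: after shifting, position t holds the token from position t - 1, so with the causal mask the output at position t can only use tokens before t to predict token t. A NumPy sketch of the operation (shift_right is a hypothetical stand-in for tl.ShiftRight(1), which appears in the model summary below):

```python
import numpy as np


def shift_right(x, n_positions=1, pad_id=0):
    """Shift a (batch, seqlen) array of token IDs right along axis 1, padding on the left."""
    padded = np.pad(x, ((0, 0), (n_positions, 0)), constant_values=pad_id)
    # Drop the last n_positions so the length stays the same.
    return padded[:, :-n_positions]


tokens = np.array([[27, 1091, 23, 1]])
print(shift_right(tokens).tolist())  # [[0, 27, 1091, 23]]
```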

Go go go!! You can do it :)

Serial[
ShiftRight(1)
Embedding_33300_512
Dropout
PositionalEncoding
Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
]
Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
]
LayerNorm
Dense_33300
LogSoftmax
]



# Part 3: Training

Now you are going to train your model. As usual, you have to define the cost function and the optimizer, and decide whether you will be training on a GPU or a CPU. In this case, you will train your model on a CPU for a few steps, and we will load in a pretrained model that you can use to predict with your own words.

### 3.1 Training the model

You will now write a function that takes in your model and trains it. To train your model, you have to decide how many times you want to iterate over the entire dataset. Each full pass is called an epoch. For each epoch, you have to go over all the data using your training iterator.

### Exercise 05

Instructions: Implement the train_model program below to train the neural network above. Here is a list of things you should do:

You will use a cross-entropy loss with the Adam optimizer. Please read the Trax documentation to get a full understanding.

The training loop that this function returns can be run using the run() method by passing in the desired number of steps.

Notice that the model will be trained for only 10 steps.

Even with this constraint, the model with the original default arguments took a very long time to finish. Because of this, some parameters are changed when defining the model that is fed into the training loop in the function above.

Step      1: Ran 1 train steps in 9.11 secs
Step      1: train CrossEntropyLoss |  10.41297626
Step      1: eval  CrossEntropyLoss |  10.41586781
Step      1: eval          Accuracy |  0.00000000

Step     10: Ran 9 train steps in 58.21 secs
Step     10: train CrossEntropyLoss |  10.41278458
Step     10: eval  CrossEntropyLoss |  10.41440201
Step     10: eval          Accuracy |  0.00000000
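One way to read these numbers: a loss of about 10.41 is roughly what a model that has learned nothing yet would get, because the cross-entropy of a uniform prediction over the vocabulary equals the log of the vocabulary size. A quick check, assuming your model keeps the default vocab_size of 33300 (Part 4 notes that only d_model, d_ff, n_layers, and n_heads were changed):

```python
import math

vocab_size = 33300  # vocabulary size of the TransformerLM configuration in Part 4
# A model that spreads probability uniformly over V tokens has cross-entropy -log(1/V) = log(V).
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.4f}")  # 10.4133
```

So after 10 steps with a tiny model, the loss is still essentially at the uniform baseline, which is consistent with the 0 accuracy.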


# Part 4: Evaluation

In this part, you will evaluate an almost exact version of the model you coded, which we trained for you to save you time. Please run the cell below to load in the model.

As you may have already noticed, the model that you trained and the pretrained model share the same overall architecture, but they have different values for some of the parameters:

Original (pretrained) model:

TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8,
dropout=0.1, max_len=4096, ff_activation=tl.Relu)


Your model:

TransformerLM(d_model=4, d_ff=16, n_layers=1, n_heads=2)


Only the parameters shown for your model were changed. The others stayed the same.
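To get a feel for why the smaller configuration trains so much faster, here is a rough count of just the parameters in the embedding table (vocab_size × d_model) and the final Dense layer (d_model × vocab_size weights plus vocab_size biases). Attention and feedforward weights are ignored in this back-of-the-envelope estimate, and the helper name is hypothetical.

```python
vocab_size = 33300


def embedding_and_output_params(d_model, vocab_size=vocab_size):
    """Weights in the embedding table plus the final Dense layer (weights and biases)."""
    embedding = vocab_size * d_model
    output_dense = d_model * vocab_size + vocab_size
    return embedding + output_dense


print(embedding_and_output_params(512))  # 34132500
print(embedding_and_output_params(4))    # 299700
```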

# Part 5: Testing with your own input

You will now test the model with your own input. You are going to implement greedy decoding, which consists of two functions. The first one identifies the next symbol: it takes the argmax of your model's output and returns that index.

### Exercise 06

Instructions: Implement the next_symbol function, which takes in cur_output_tokens and the trained model and returns the index of the next word.
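A sketch of the greedy step, using a hypothetical toy model (a fixed table of next-token log-probabilities that shifts its input right like the language model) in place of the trained transformer, so the block runs without trax. Padding the tokens to a power of two is an assumption in this sketch; it is a common trick with JIT-compiled models because it limits how many distinct input shapes get compiled.

```python
import numpy as np

# Hypothetical toy "model": a table of log-probabilities for the token following each token.
SUCCESSOR = {0: 2, 1: 5, 2: 3, 3: 4, 4: 1, 5: 2}
TABLE = np.full((6, 6), np.log(0.04))
for token, following in SUCCESSOR.items():
    TABLE[token, following] = np.log(0.8)


def toy_model(batch):
    """Map (1, length) token IDs to (1, length, vocab) log-probs, shifting right like the LM."""
    shifted = np.pad(batch, ((0, 0), (1, 0)))[:, :-1]
    return TABLE[shifted]


def next_symbol(cur_output_tokens, model):
    """Return the index of the most likely next token (greedy choice)."""
    token_length = len(cur_output_tokens)
    # Pad to the next power of two so only a few distinct input lengths reach the model.
    padded_length = 2 ** int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    # With the input shifted right, index token_length predicts the token after the last real one.
    log_probs = model(np.array(padded)[None, :])[0, token_length]
    return int(np.argmax(log_probs))


print(next_symbol([5, 1, 0], toy_model))  # 2
```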

'The'



### 5.1 Greedy decoding

Now you will implement the greedy_decode algorithm, which calls the next_symbol function. It takes in the input_sentence and the trained model and returns the decoded sentence.

### Exercise 07

Instructions: Implement the greedy_decode algorithm.
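Greedy decoding feeds the article followed by the `<pad>` separator, the same layout used in training, and keeps appending the argmax token until `<EOS>`. The self-contained sketch below repeats a hypothetical toy model and next_symbol so it runs on its own; the token IDs are arbitrary, not real vocabulary entries. Since greedy_decode in the notebook takes the sentence itself, tokenize and detokenize would presumably wrap a loop like this one.

```python
import numpy as np

EOS, SEP = 1, 0
# Hypothetical toy model: a table of next-token log-probabilities, shifted right like the LM.
SUCCESSOR = {0: 2, 1: 5, 2: 3, 3: 4, 4: 1, 5: 2}
TABLE = np.full((6, 6), np.log(0.04))
for token, following in SUCCESSOR.items():
    TABLE[token, following] = np.log(0.8)


def toy_model(batch):
    shifted = np.pad(batch, ((0, 0), (1, 0)))[:, :-1]
    return TABLE[shifted]


def next_symbol(cur_output_tokens, model):
    token_length = len(cur_output_tokens)
    padded_length = 2 ** int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    return int(np.argmax(model(np.array(padded)[None, :])[0, token_length]))


def greedy_decode(input_tokens, model, max_new_tokens=50):
    """Append the separator, then repeatedly take the argmax token until <EOS>."""
    # Assumption: the article tokens already end with <EOS>, as in the training data.
    cur_output_tokens = list(input_tokens) + [SEP]
    generated = []
    while len(generated) < max_new_tokens:
        symbol = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(symbol)
        generated.append(symbol)
        if symbol == EOS:
            break
    return generated


print(greedy_decode([5, 1], toy_model))  # [2, 3, 4, 1]
```

The max_new_tokens cap guards against a model that never emits `<EOS>`.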

It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips.

:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>



Congratulations on finishing this week’s assignment! You did a lot of work, and now you should have a better understanding of the decoder part of Transformers and how Transformers can be used for text summarization.

Keep it up!