Assignment 2: Transformer Summarizer
Welcome to the second assignment of Course 4. In this assignment you will explore summarization using the Transformer model. Yes, you will implement the Transformer decoder from scratch, but we will slowly walk you through it. There are many hints in this notebook, so feel free to use them as needed.
Outline
- Introduction
- Part 1: Importing the dataset
- Part 2: Summarization with transformer
- Part 3: Training
- Part 4: Evaluation
- Part 5: Testing with your own input
Summarization is an important task in natural language processing and can be useful in many consumer and enterprise applications. For example, bots can scrape articles and summarize them, and you can then run sentiment analysis to gauge the sentiment about certain stocks. And who wants to read an entire article or a long email today when a Transformer can summarize it for you? Let's get started. By completing this assignment you will learn to:
- Use built-in functions to preprocess your data
- Implement DotProductAttention
- Implement Causal Attention
- Understand how attention works
- Build the transformer model
- Evaluate your model
- Summarize an article
As you can tell, this model is slightly different from the ones you have already implemented. It is heavily based on attention and does not rely on sequential (recurrent) processing, which allows for parallel computation.
`import sys`
## Part 1: Importing the dataset
Trax makes it easy to work with TensorFlow datasets:
`# This will download the dataset if no data_dir is specified.`
## 1.1 Tokenize & Detokenize helper functions
Just like in the previous assignment, the cell above loads in the encoder for you. Given any data set, you have to be able to map words to their indices, and indices to their words. The inputs and outputs to your Trax models are usually tensors of numbers where each number corresponds to a word. If you were to process your data manually, you would have to make use of the following:
- word2Ind: a dictionary mapping the word to its index.
- ind2Word: a dictionary mapping the index to its word.
- word2Count: a dictionary mapping the word to the number of times it appears.
- num_words: total number of words that have appeared.
Since you have already implemented these in previous assignments of the specialization, we will provide you with helper functions that will do this for you. Run the cell below to get the following functions:
- tokenize: converts a text sentence to its corresponding token list (i.e. list of indices). Also converts words to subwords.
- detokenize: converts a token list to its corresponding sentence (i.e. string).
`def tokenize(input_str, EOS=1):`
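For reference, here is a hedged sketch of how such helpers might be written on top of `trax.data`. The `VOCAB_DIR`/`VOCAB_FILE` values are placeholders, not the assignment's actual vocabulary files; the notebook's own cell above defines the real helpers.

```python
import trax

# Placeholder vocabulary settings -- the notebook supplies the real vocab_dir/vocab_file.
VOCAB_DIR = 'vocab_dir/'
VOCAB_FILE = 'summarize.subword.subwords'

def tokenize(input_str, EOS=1):
    """Text string -> list of subword ids, terminated by the EOS id."""
    inputs = next(trax.data.tokenize(iter([input_str]),
                                     vocab_dir=VOCAB_DIR,
                                     vocab_file=VOCAB_FILE))
    return list(inputs) + [EOS]

def detokenize(integers):
    """List of subword ids -> text string."""
    return trax.data.detokenize(integers,
                                vocab_dir=VOCAB_DIR,
                                vocab_file=VOCAB_FILE)
```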
## 1.2 Preprocessing for Language Models: Concatenate It!
This week you will use a language model -- the Transformer decoder -- to solve an input-output problem. As you know, language models only predict the next word; they have no separate notion of inputs. To create a single sequence suitable for a language model, we concatenate inputs with targets, putting a separator in between. We also need to create a mask -- with 0s at the input positions and 1s at the target positions -- so that the model is not penalized for mis-predicting the article and focuses only on the summary. See the preprocess function below for how this is done.
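Before looking at the notebook's version, here is a minimal sketch of the idea, assuming token ids for the article and summary and the special ids EOS=1 and pad=0 used in this assignment; the notebook's `preprocess` cell below is the authoritative version.

```python
import numpy as np

EOS = 1  # end-of-sequence id
PAD = 0  # padding id, also used as the separator between article and summary

def preprocess_example(article_ids, summary_ids):
    # One long sequence: [article] <EOS> <pad> [summary] <EOS>
    joint = np.concatenate([article_ids, [EOS, PAD], summary_ids, [EOS]]).astype(np.int32)
    # Loss mask: 0 over the article part, 1 over the summary part,
    # so the cross entropy loss only counts the summary tokens.
    mask = np.concatenate([np.zeros(len(article_ids) + 2, dtype=np.int32),
                           np.ones(len(summary_ids) + 1, dtype=np.int32)])
    # Language-model style triple: (input, target, loss weights)
    return joint, joint, mask
```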
`# Special tokens`
`# prints mask, 0s on article, 1s on summary`
Single example mask:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
 ... 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1]
(truncated: the mask is 0 over every article position and 1 over every summary position)
`# prints: [Example][<EOS>][<pad>][Example Summary][<EOS>]`
Single example:
By . Margot Peppers . Nigerian and Cameroonian pop star Dencia has hit
out at Lupita Nyong'o for her new contract with Lancome, accusing her
of bowing to 'white people companies'. In an angry tweet directed at
the 12 Years A Slave star, she wrote: 'Oh @Lupita_Nyongo cln't talk
abt the bleaching creams white people (Companies) make cuz the white
man pays her, they own her!! [sic]'. The comment comes just a month
after Miss Nyong'o mentioned Dencia - who has been accused of
marketing her own brand of skin-bleaching cream called Whitenicious -
in a speech about learning to value the color of her own skin. Scroll
down for video . Butting heads: Nigerian and Cameroonian pop star
Dencia has hit out at Lupita Nyong'o for her new contract with
Lancome, accusing her of bowing to 'white people companies' Fighting
words: In a tweet directed at the 12 Years A Slave star, she wrote:
'Oh @Lupita_Nyongo cln't talk abt the bleaching creams white people
(Companies) make cuz the white man pays her, they own her!! [sic]' The
pop star is no stranger to . controversy; in a February interview with
Ebony, she all but admitted . that Whitenicious is intended as a skin-
lightener, not as a cure for . dark spots as it claims. 'When . you
take that picture and you put a picture of Dencia darker, this is .
what you're telling people - the product really works,' she said. 'And
guess what? People really want to buy it. It's what it is. I don't
really care.' Given her defiant and hypocritical attitude, it's no
surprise the fiery singer was angered when Miss Nyong'o called her out
in a speech at Essence's Black Women in Hollywood event on February
27. Influential: In a recent speech, Miss Nyong'o read out loud a
letter from a fan who said she decided not to buy Dencia's skin-
whitening cream Whitenicious because the actress had inspired her to
love her own skin . On-screen: Miss Nyong'o won an Oscar for Best
Supporting Actress for her role in 2013 film 12 Years A Slave . In her
talk, the 30-year-old opened up about how conventional standards of
beauty once affected her self-esteem, reading aloud a letter written
to her by a young girl who viewed her as a role model. 'Dear Lupita,'
reads the letter. 'I think you're really lucky to be this black but
yet this successful in Hollywood overnight. I was just about to buy
Dencia's Whitenicious cream to lighten my skin when you appeared on
the world map and saved me.' 'My heart bled a little when I read those
words,' the actress said through tears, explaining how as a child,
she, too, would pray that she'd one day wake up with lighter skin.
Hypocritical: Dencia is no stranger to controversy; in a February
interview with Ebony, she essentially admitted that Whitenicious is
intended as a skin-lightener, not as a cure for dark spots as it
claims . Perpetuating the problem: 'When you take that picture and you
put a picture of Dencia darker, this is what you're telling people -
the product really works,' she said. 'And guess what? People really
want to buy it' But while the actress saw the letter as a source of
inspiration, Dencia took it as a personal attack. After her angry
tweet at Miss Nyong'o, criticism poured in, with one person tweeting:
'B**** lupita is the new face of Lancôme!! SHE WINS!! And you're just
TRASH [sic]'. In her response, Dencia said of the cosmetics company:
'But they sell bleaching cream tho [sic]'. The pop star is likely
referring to Lancome's Blanc Expert range of cosmetics, which are
actually advertised as 'brighteners' that 'regulate melanin production
and awaken the luminosity of the skin'. And as far as Dencia's claim
that Lancome is a 'white people company', a quick perusal of the
website reveals that it has a number of concealers and foundations in
darker skin tones.<EOS><pad>Dencia's comment is hypocritical
considering she recently courted controversy for marketing 'dark spot
remover' Whitenicious, which is frequently used as a skin-whitening
cream .<EOS>
## 1.3 Batching with bucketing
As in the previous week, we use bucketing to create batches of data.
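As a rough illustration of what such a pipeline can look like with `trax.data`, here is a hedged sketch. The boundaries and batch sizes are illustrative assumptions, and `train_stream`/`eval_stream` stand in for the example generators defined earlier; the notebook's cells below build the real batch generators.

```python
import trax

# Group examples of similar length together and give shorter sequences bigger batches.
boundaries = [128, 256, 512, 1024]
batch_sizes = [16, 8, 4, 2, 1]  # one more entry than boundaries, for the longest bucket

# BucketByLength returns a function that maps a stream of examples to a stream of batches.
train_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(train_stream)
eval_batch_stream = trax.data.BucketByLength(boundaries, batch_sizes)(eval_stream)

# Each batch is a tuple of (inputs, targets, mask) arrays,
# e.g. with shape (2, 1024) for the longest bucket, as printed below.
```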
`# Bucketing to create batched generators.`
`# Every execution will result in generation of a different article`
(2, 1024)
`# print corresponding integer values`
[ 27 1091 23 46 3873 1248 16013 256 11599 23297 102 68
  ... (article token ids) ... 25733 812 10 1 0 4927 7577 28 8478 10120
  ... (summary token ids) ... 2635 11850 213 1621 2104 1 0 0 0 ... 0]
(truncated: the article ids, then <EOS> = 1 followed by <pad> = 0, then the summary ids, then <EOS> = 1 and trailing <pad> = 0s up to the bucket length of 1024)
Things to notice:
- First we see the integer values corresponding to the words of the article.
- The first 1 represents the <EOS> tag of the article.
- It is followed by a 0, which represents the <pad> separator tag.
- After that first 0 (<pad> tag), the values correspond to the words of the article's summary.
- The second 1 represents the <EOS> tag of the summary.
- All trailing 0s are <pad> tags, appended to keep every example in the bucket at the same length (if you don't see them, the example is already at max length).
`# print the article and its summary`
Article:
A woman has been charged with reckless manslaughter after her
boyfriend's mother tried to stop them fighting and suffered a fatal
heart attack. Claudia Yanira Hernandez Soriano, 25, and Juan Francisco
Martinez Rojas, 28, started punching and scratching each other after
they returned to their Bergen, New Jersey home following a party early
on Monday. When Ana Angelina Rojas-Jovel, 45, tried to break them up,
Hernandez Soriano assaulted the woman, according to the Bergen County
Prosecutor. 'During the assault, the victim apparently suffered a
cardiac event which resulted in her death,' Prosecutor John L.
Molinelli said in a statement. Fight: Claudia Yanira Hernandez
Soriano, 25, above, and her boyfriend Juan Francisco Martinez Rojas,
28, started punching and scratching each other at their home on Monday
when his mother intervened . Injured: Martinez Rojas' booking shot
shows the scratches on his face from the domestic dispute . A seven-
year-old child also witnessed the fight, according to the prosecutor,
but he did not reveal the relationship between the adults and the
youngster. Police responded to a 911 call from the apartment just
after 4am on Monday and when they arrived, they found Rojas-Jovel dead
on a bedroom floor. 'There were no obvious signs of trauma to the
victim, however... the [couple] displayed signs of injury and appeared
to have been involved in a domestic assault,' the prosecutor said. In
their booking photos, both Hernandez Soriano and Martinez Rojas have
scratches on their faces and necks. The pair were interviewed, as were
the child and other residents. Scene: Soriano allegedly then assaulted
the woman, Ana Angelina Rojas-Jovel, and she suffered a cardiac arrest
at the first-floor apartment at the house (pictured) and died before
police arrived at the scene . The Bergen County Medical Examiner's
Office conducted an autopsy on Rojas-Jovel's body, but results were
pending toxicology tests, the prosecutor said. Hernandez Soriano was
charged with manslaughter, endangering the welfare of a child,
domestic violence simple assault and hindering apprehension, according
to authorities. Molinelli said Hernandez Soriano also hid evidence -
but would not detail what it was - which investigators later recovered
in a search at the crime scene. She was held at the Bergen County Jail
on $250,000 bail. Martinez Rojas was also charged with child
endangerment and domestic violence simple assault and sent to the
county jail on $75,000 bail. A court hearing has been scheduled for
Thursday morning at Hackensack Superior Court.<EOS><pad>ClaudiaYanira
Hernandez Soriano, 25, and Juan Francisco Martinez Rojas, 28, started
fighting after returning from a party on Monday morning . When his
mother, Ana Angelina Rojas-Jovel, 45, tried to stop them, Hernandez
Soriano allegedly assaulted her . She suffered cardiac arrest and
police arrived to find her dead . A seven-year-old girl witnessed the
fight .<EOS><pad><pad><pad> ... <pad>
(trailing <pad> tokens truncated)
You can see that the data has the following structure:

[Article] -> <EOS> -> <pad> -> [Article Summary] -> <EOS> -> (possibly multiple) <pad>

The loss is computed only on the summary, using cross entropy as the loss function.
# Part 2: Summarization with transformer
Now that we have given you the data generator and handled the preprocessing for you, it is time for you to build your own model. We saved you some time because you have already preprocessed data in this specialization, so we would rather you spend your time on the next steps.
You will implement attention from scratch and then use it in your transformer model. Concretely, you will come to understand how attention works and how it is used inside the decoder.
## 2.1 Dot product attention
Now you will implement dot product attention, which takes in a query, key, value, and a mask, and returns the attention output.
Here are some helper functions that will help you create tensors and display useful information:
- create_tensor: creates a jax numpy array from a list of lists.
- display_tensor: prints out the shape and the actual tensor.
`def create_tensor(t):`
Before implementing it yourself, you can play around with a toy example of dot product attention without the softmax operation. Technically it would not be dot product attention without the softmax, but this is done to avoid giving away too much of the answer; the idea is to display these tensors to give you a sense of what they look like.
The formula for attention is this one:
\[ \text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}} + M\right) V \tag{1} \]
\(d_{k}\) stands for the dimension of queries and keys.
The query, key, value, and mask tensors are provided for this example.
Notice that the masking is done using very negative values, which yield an effect similar to using \(-\infty\).
`q = create_tensor([[1, 0, 0], [0, 1, 0]])`
query shape: (2, 3)
[[1 0 0]
[0 1 0]]
key shape: (2, 3)
[[1 2 3]
[4 5 6]]
value shape: (2, 3)
[[0 1 0]
[1 0 1]]
mask shape: (2, 2)
[[ 0.e+00 0.e+00]
[-1.e+09 0.e+00]]
Expected Output:
query shape: (2, 3)
[[1 0 0]
[0 1 0]]
key shape: (2, 3)
[[1 2 3]
[4 5 6]]
value shape: (2, 3)
[[0 1 0]
[1 0 1]]
mask shape: (2, 2)
[[ 0.e+00 0.e+00]
[-1.e+09 0.e+00]]
`q_dot_k = q @ k.T / jnp.sqrt(3)`
query dot key shape: (2, 2)
[[0.57735026 2.309401 ]
[1.1547005 2.8867514 ]]
Expected Output:
query dot key shape: (2, 2)
[[0.57735026 2.309401 ]
[1.1547005 2.8867514 ]]
`masked = q_dot_k + m`
masked query dot key shape: (2, 2)
[[ 5.7735026e-01 2.3094010e+00]
[-1.0000000e+09 2.8867514e+00]]
Expected Output:
masked query dot key shape: (2, 2)
[[ 5.7735026e-01 2.3094010e+00]
[-1.0000000e+09 2.8867514e+00]]
`display_tensor(masked @ v, 'masked query dot key dot value')`
masked query dot key dot value shape: (2, 3)
[[ 2.3094010e+00 5.7735026e-01 2.3094010e+00]
[ 2.8867514e+00 -1.0000000e+09 2.8867514e+00]]
Expected Output:
masked query dot key dot value shape: (2, 3)
[[ 2.3094010e+00 5.7735026e-01 2.3094010e+00]
[ 2.8867514e+00 -1.0000000e+09 2.8867514e+00]]
In order to use the previous dummy tensors to test some of the graded functions, a batch dimension should be added to them so they mimic the shape of real-life examples. The mask is also replaced by a version of it that resembles the one that is used by trax:
`q_with_batch = q[None,:]`
query with batch dim shape: (1, 2, 3)
[[[1 0 0]
[0 1 0]]]
key with batch dim shape: (1, 2, 3)
[[[1 2 3]
[4 5 6]]]
value with batch dim shape: (1, 2, 3)
[[[0 1 0]
[1 0 1]]]
boolean mask shape: (2, 2)
[[ True True]
[False True]]
Expected Output:
query with batch dim shape: (1, 2, 3)
[[[1 0 0]
[0 1 0]]]
key with batch dim shape: (1, 2, 3)
[[[1 2 3]
[4 5 6]]]
value with batch dim shape: (1, 2, 3)
[[[0 1 0]
[1 0 1]]]
boolean mask shape: (2, 2)
[[ True True]
[False True]]
### Exercise 01
Instructions: Implement the dot product attention. Concretely, implement the following equation:
\[ \text{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}} + M\right) V \tag{1} \]
\(Q\) - query, \(K\) - key, \(V\) - values, \(M\) - mask, \({d_k}\) - depth/dimension of the queries and keys (used for scaling down)
You can implement this formula using either trax numpy (trax.math.numpy) or regular numpy, but it is recommended to use jnp.
Something to take into consideration is that within trax, the masks are tensors of True/False values, not 0s and \(-\infty\) as in the previous example. Within the graded function, don't apply the mask by summing matrices; instead use jnp.where() and treat the mask as a tensor of boolean values, with False for positions that need to be masked and True for the ones that don't.
Also take into account that the real tensors are far more complex than the toy ones you just played with. Because of this, avoid shortened operations such as @ for dot product or .T for transposing. Use jnp.matmul() and jnp.swapaxes() instead.
This is the self-attention block for the transformer decoder. Good luck!
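As a reference, here is a hedged sketch of scaled dot-product attention following the instructions above. It uses plain jax.numpy and jax.nn.softmax; the graded cell (named DotProductAttention) may expect trax's own numpy and a log-sum-exp based softmax instead, so treat this as an illustration rather than the official solution.

```python
import jax.numpy as jnp
from jax.nn import softmax

def dot_product_attention(query, key, value, mask):
    """query/key/value: (batch, seqlen, depth); mask: boolean, True = keep."""
    depth = query.shape[-1]
    # Scaled scores: Q K^T / sqrt(d_k), using matmul/swapaxes rather than @/.T
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)
    # Boolean mask: keep scores where True, push the rest to a very negative value.
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax over the last axis, then weight the values.
    return jnp.matmul(softmax(dots, axis=-1), value)
```

Called on the batched toy tensors above (q_with_batch, k_with_batch, v_with_batch, m_bool), this should reproduce the DeviceArray shown in the expected output below.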
`# UNQ_C1`
`DotProductAttention(q_with_batch, k_with_batch, v_with_batch, m_bool)`
DeviceArray([[[0.8496746 , 0.15032545, 0.8496746 ],
[1. , 0. , 1. ]]], dtype=float32)
Expected Output:
DeviceArray([[[0.8496746 , 0.15032545, 0.8496746 ],
[1. , 0. , 1. ]]], dtype=float32)
## 2.2 Causal Attention
Now you are going to implement causal attention: multi-headed attention with a mask to attend only to words that occurred before.
<img src = "causal.png">
In the image above, a word can see everything that is before it, but not what is after it. To implement causal attention, you will have to transform vectors and do many reshapes. You will need to implement the functions below.
### Exercise 02
Implement the following functions that will be needed for Causal Attention:
- compute_attention_heads: Gets an input \(x\) of dimension (batch_size, seqlen, n_heads \(\times\) d_head), splits the last (depth) dimension, and stacks it into the zeroth dimension to allow matrix multiplication (batch_size \(\times\) n_heads, seqlen, d_head).
- dot_product_self_attention: Creates a mask matrix with False values above the diagonal and True values on and below it, then calls DotProductAttention, which implements dot product self attention.
- compute_attention_output: Undoes compute_attention_heads by splitting the first (batch \(\times\) heads) dimension and stacking it into the last (depth) dimension (batch_size, seqlen, n_heads \(\times\) d_head). These operations concatenate (stack/merge) the heads.
Next there are some toy tensors which may serve to give you an idea of the data shapes and operations involved in Causal Attention. They are also useful to test out your functions!
```python
tensor2d = create_tensor(q)
display_tensor(tensor2d, 'query matrix (2D tensor)')
tensor4d2b = create_tensor([[q, q], [q, q]])
display_tensor(tensor4d2b, 'batch of two (multi-head) collections of query matrices (4D tensor)')
tensor3dc = create_tensor([jnp.concatenate([q, q], axis = -1)])
display_tensor(tensor3dc, 'one batch of concatenated heads of query matrices (3d tensor)')
tensor3dc3b = create_tensor([jnp.concatenate([q, q], axis = -1), jnp.concatenate([q, q], axis = -1), jnp.concatenate([q, q], axis = -1)])
display_tensor(tensor3dc3b, 'three batches of concatenated heads of query matrices (3d tensor)')
```
query matrix (2D tensor) shape: (2, 3)
[[1 0 0]
[0 1 0]]
batch of two (multi-head) collections of query matrices (4D tensor) shape: (2, 2, 2, 3)
[[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]]
one batch of concatenated heads of query matrices (3d tensor) shape: (1, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
three batches of concatenated heads of query matrices (3d tensor) shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
It is important to know that the following 3 functions would normally be defined within the CausalAttention function further below. However, that would make them harder to test. Because of this, they are shown here individually, using a closure (when necessary) that simulates them being inside of the CausalAttention function. This is done because they rely on some variables that are only accessible from within CausalAttention.
Support Functions
compute_attention_heads : Gets an input \(x\) of dimension (batch_size, seqlen, n_heads \(\times\) d_head) and splits the last (depth) dimension and stacks it to the zeroth dimension to allow matrix multiplication (batch_size \(\times\) n_heads, seqlen, d_head).
For the closures you only have to fill the inner function.
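Here is a hedged sketch of one way to write this closure with jax.numpy reshapes and transposes, consistent with the shapes in the expected output below; the graded cell may organize the reshapes differently.

```python
import jax.numpy as jnp

def compute_attention_heads_closure(n_heads, d_head):
    def compute_attention_heads(x):
        # x: (batch_size, seqlen, n_heads * d_head)
        batch_size, seqlen = x.shape[0], x.shape[1]
        x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))  # split depth into heads
        x = jnp.transpose(x, (0, 2, 1, 3))                         # (batch, heads, seqlen, d_head)
        return jnp.reshape(x, (-1, seqlen, d_head))                # (batch * heads, seqlen, d_head)
    return compute_attention_heads
```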
`# UNQ_C2`
`display_tensor(tensor3dc3b, "input tensor")`
input tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
output tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
Expected Output:
input tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
output tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
dot_product_self_attention : Creates a mask matrix with False values above the diagonal and True values on and below it, then calls DotProductAttention, which implements dot product self attention.
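A hedged sketch, assuming the DotProductAttention function from earlier and jax.numpy's jnp.tril for the lower-triangular (causal) mask:

```python
import jax.numpy as jnp

def dot_product_self_attention(q, k, v):
    """Causal self-attention: position i may only attend to positions <= i."""
    mask_size = q.shape[-2]
    # (1, mask_size, mask_size) boolean mask, True at and below the diagonal.
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
    return DotProductAttention(q, k, v, mask)
```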
`# UNQ_C3`
`dot_product_self_attention(q_with_batch, k_with_batch, v_with_batch)`
DeviceArray([[[0. , 1. , 0. ],
[0.8496746 , 0.15032543, 0.8496746 ]]], dtype=float32)
Expected Output:
DeviceArray([[[0. , 1. , 0. ],
[0.8496746 , 0.15032543, 0.8496746 ]]], dtype=float32)
compute_attention_output : Undoes compute_attention_heads by splitting first (vertical) dimension and stacking in the last (depth) dimension (batch_size, seqlen, n_heads \(\times\) d_head). These operations concatenate (stack/merge) the heads.
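A hedged sketch of this closure, simply inverting the reshapes done in compute_attention_heads above:

```python
import jax.numpy as jnp

def compute_attention_output_closure(n_heads, d_head):
    def compute_attention_output(x):
        # x: (batch_size * n_heads, seqlen, d_head)
        seqlen = x.shape[1]
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))      # recover the heads axis
        x = jnp.transpose(x, (0, 2, 1, 3))                     # (batch, seqlen, heads, d_head)
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))  # merge heads back into depth
    return compute_attention_output
```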
`# UNQ_C4`
`display_tensor(result_cah, "input tensor")`
input tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
output tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
Expected Output:
input tensor shape: (6, 2, 3)
[[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]
[[1 0 0]
[0 1 0]]]
output tensor shape: (3, 2, 6)
[[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]
[[1 0 0 1 0 0]
[0 1 0 0 1 0]]]
Causal Attention Function
Now it is time for you to put everything together within the CausalAttention, or masked multi-head attention, function.
Instructions: Implement causal attention. Your model returns the causal attention through a tl.Serial with the following:
- tl.Branch : consisting of 3 [tl.Dense(d_feature), ComputeAttentionHeads] branches, to account for the queries, keys, and values.
- tl.Fn : takes in the dot_product_self_attention function and uses it to compute the attention with \(Q\), \(K\), \(V\).
- tl.Fn : takes in compute_attention_output_closure to merge the attention heads back together.
- tl.Dense : final dense layer, with dimension d_feature.
Remember that in order for trax to properly handle the functions you just defined, they need to be added as layers using the tl.Fn() function.
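Here is a hedged sketch of how these pieces could be wired together, assuming the closures defined above; layer names and exact trax idioms may differ slightly from the graded cell, but the structure matches the printed model below.

```python
from trax import layers as tl

def CausalAttention(d_feature, n_heads, mode='train'):
    assert d_feature % n_heads == 0
    d_head = d_feature // n_heads

    # Wrap the plain functions as trax layers with tl.Fn.
    ComputeAttentionHeads = tl.Fn(
        'AttnHeads', compute_attention_heads_closure(n_heads, d_head), n_out=1)

    return tl.Serial(
        tl.Branch(  # one [Dense, split-into-heads] branch each for Q, K and V
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads],
            [tl.Dense(d_feature), ComputeAttentionHeads],
        ),
        tl.Fn('DotProductAttn', dot_product_self_attention, n_out=1),
        tl.Fn('AttnOutput',
              compute_attention_output_closure(n_heads, d_head), n_out=1),
        tl.Dense(d_feature),  # final projection back to d_feature
    )
```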
`# UNQ_C5`
`# Take a look at the causal attention model`
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Expected Output:
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
## 2.3 Transformer decoder block
Now that you have implemented causal attention, you will implement the transformer decoder block. Concretely, you will now implement the block shown in the figure for this section.
To implement this function, you will have to call the CausalAttention
or Masked multi-head attention function you implemented above. You will have to add a feedforward which consists of:
- tl.LayerNorm : used to layer normalize
- tl.Dense : the dense layer
- ff_activation : the feed forward activation (we use ReLU here).
- tl.Dropout : dropout layer
- tl.Dense : dense layer
- tl.Dropout : dropout layer
Finally, once you implement the feedforward, you can go ahead and implement the entire block using:
- tl.Residual : takes in tl.LayerNorm(), the causal attention block, and tl.Dropout.
- tl.Residual : takes in the feedforward block you will implement.
### Exercise 03
Instructions: Implement the transformer decoder block. Good luck!
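A hedged sketch of the decoder block, assuming the CausalAttention layer above; the two residual branches mirror the printed structure below.

```python
from trax import layers as tl

def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    # Feed-forward sublayer: LayerNorm -> Dense -> activation -> Dropout -> Dense -> Dropout
    feed_forward = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),  # e.g. tl.Relu
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    return [
        # Attention sublayer wrapped in a residual connection.
        tl.Residual(
            tl.LayerNorm(),
            CausalAttention(d_model, n_heads=n_heads, mode=mode),
            tl.Dropout(rate=dropout, mode=mode),
        ),
        # Feed-forward sublayer wrapped in a residual connection.
        tl.Residual(feed_forward),
    ]
```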
`# UNQ_C6`
`# Take a look at the decoder block`
[Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
], Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
Add_in2
]]
Expected Output:
[Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
], Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
Add_in2
]]
## 2.4 Transformer Language Model
You will now bring it all together. In this part you will use all the subcomponents you previously built to make the final model. Concretely, you will be implementing the model shown in the figure for this section.
### Exercise 04
Instructions: Previously you coded the decoder block. Now you will code the transformer language model. Here is what you will need:
- positional_encoder - a list containing an embedding layer, a dropout layer, and a positional encoding layer.
- A list of n_layers decoder blocks.
- tl.Serial : takes in the following layers or lists of layers:
  - tl.ShiftRight : shifts the tensor to the right by padding on axis 1.
  - positional_encoder : encodes the text positions.
  - decoder_blocks : the ones you created.
  - tl.LayerNorm : a layer norm.
  - tl.Dense : takes in the vocab_size.
  - tl.LogSoftmax : to predict.
Go go go!! You can do it :)
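A hedged sketch of the full language model, assuming the DecoderBlock above; the default hyperparameters follow the pretrained model described later (vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8, dropout=0.1, max_len=4096).

```python
from trax import layers as tl

def TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6,
                  n_heads=8, dropout=0.1, max_len=4096,
                  mode='train', ff_activation=tl.Relu):
    # Embedding, dropout and positional encoding applied to the (shifted) token ids.
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]
    # n_layers decoder blocks stacked on top of each other.
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
        for _ in range(n_layers)
    ]
    return tl.Serial(
        tl.ShiftRight(mode=mode),  # predict token t from tokens < t
        positional_encoder,
        decoder_blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
```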
`# UNQ_C7`
`# Take a look at the Transformer`
Serial[
ShiftRight(1)
Embedding_33300_512
Dropout
PositionalEncoding
Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
]
Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
Add_in2
]
LayerNorm
Dense_33300
LogSoftmax
]
Expected Output:
Serial[
ShiftRight(1)
Embedding_33300_512
Dropout
PositionalEncoding
Serial[
Branch_out2[
None
Serial[
LayerNorm
Serial[
Branch_out3[
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
[Dense_512, AttnHeads]
]
DotProductAttn_in3
AttnOutput
Dense_512
]
Dropout
]
]
Add_in2
]
Serial[
Branch_out2[
None
Serial[
LayerNorm
Dense_2048
Relu
Dropout
Dense_512
Dropout
]
]
Add_in2
]
LayerNorm
Dense_33300
LogSoftmax
]
# Part 3: Training
Now you are going to train your model. As usual, you have to define the cost function and the optimizer, and decide whether you will be training on a GPU or a CPU. In this case, you will train your model on a CPU for a few steps, and then we will load in a pre-trained model that you can use to predict with your own words.
You will now write a function that takes in your model and trains it. To train your model you have to decide how many times you want to iterate over the entire data set; each iteration is defined as an epoch. For each epoch, you have to go over all the data using your training iterator.
### Exercise 05
Instructions: Implement the train_model function below to train the neural network above. Here is a list of things you should do:
- Create the train task by calling trax.supervised.training.TrainTask and pass in the following:
  - labeled_data = train_gen
  - loss_fn = tl.CrossEntropyLoss()
  - optimizer = trax.optimizers.Adam(0.01)
  - lr_schedule = lr_schedule
- Create the eval task by calling trax.supervised.training.EvalTask and pass in the following:
  - labeled_data = eval_gen
  - metrics = tl.CrossEntropyLoss() and tl.Accuracy()
- Create the training loop by calling trax.supervised.training.Loop and pass in the following:
  - TransformerLM
  - train_task
  - eval_tasks = [eval_task]
  - output_dir = output_dir
You will be using cross entropy loss with the Adam optimizer. Please read the Trax documentation to get a full understanding.
The training loop that this function returns can be run using the run() method, passing in the desired number of steps.
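A hedged sketch of such a function. The argument names, their order, and the learning-rate schedule follow the trax version used when this notebook was written and may differ in other releases; `output_dir` is just a placeholder path.

```python
import trax
from trax import layers as tl
from trax.supervised import training

def train_model(TransformerLM, train_gen, eval_gen, output_dir='./model'):
    # Warm up, then decay the learning rate with an inverse-sqrt schedule.
    lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000, max_value=0.01)

    train_task = training.TrainTask(
        train_gen,                   # labeled data
        tl.CrossEntropyLoss(),       # loss
        trax.optimizers.Adam(0.01),  # optimizer
        lr_schedule=lr_schedule,
        n_steps_per_checkpoint=10,
    )
    eval_task = training.EvalTask(
        eval_gen,
        [tl.CrossEntropyLoss(), tl.Accuracy()],
    )
    return training.Loop(TransformerLM(),
                         train_task,
                         eval_tasks=[eval_task],
                         output_dir=output_dir)
```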
`from trax.supervised import training`
Notice that the model will be trained for only 10 steps.
Even with this constraint, the model with its original default arguments took a very long time to finish. Because of this, some parameters are reduced when defining the model that is fed into the training loop in the function above.
`# Should take around 1.5 minutes`
Step 1: Ran 1 train steps in 9.11 secs
Step 1: train CrossEntropyLoss | 10.41297626
Step 1: eval CrossEntropyLoss | 10.41586781
Step 1: eval Accuracy | 0.00000000
Step 10: Ran 9 train steps in 58.21 secs
Step 10: train CrossEntropyLoss | 10.41278458
Step 10: eval CrossEntropyLoss | 10.41440201
Step 10: eval Accuracy | 0.00000000
# Part 4: Evaluation
### 4.1 Loading in a trained model
In this part you will evaluate the model by loading in an almost exact version of the one you coded, which we trained for you to save you time. Please run the cell below to load it in.
As you may have already noticed, the model that you trained and the pretrained model share the same overall architecture, but they have different values for some of the parameters:
Original (pretrained) model:
TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8,
dropout=0.1, max_len=4096, ff_activation=tl.Relu)
Your model:
TransformerLM(d_model=4, d_ff=16, n_layers=1, n_heads=2)
Only the parameters shown for your model were changed. The others stayed the same.
`# Get the model architecture`
# Part 5: Testing with your own input
You will now test the model with your own input by implementing greedy decoding. This consists of two functions. The first one identifies the next symbol: it takes the argmax of your model's output and returns that index.
### Exercise 06
Instructions: Implement the next_symbol function that takes in the cur_output_tokens and the trained model, and returns the index of the next word.
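A hedged sketch of one way to do this, assuming the model maps a batch of token ids to per-position log probabilities; padding the running output up to a power of two mirrors the sequence lengths used during training (an assumption here, not something this sketch verifies).

```python
import numpy as np

def next_symbol(cur_output_tokens, model):
    token_length = len(cur_output_tokens)
    # Pad the running output up to the next power of two.
    padded_length = 2 ** int(np.ceil(np.log2(token_length + 1)))
    padded = cur_output_tokens + [0] * (padded_length - token_length)
    padded_with_batch = np.array(padded)[None, :]  # add a batch dimension
    log_probs = model(padded_with_batch)           # (1, padded_length, vocab_size)
    # Distribution for the position right after the tokens produced so far.
    return int(np.argmax(log_probs[0, token_length, :]))
```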
`# UNQ_C9`
`# Test it out!`
'The'
Expected Output:
'The'
Now you will implement the greedy_decode algorithm, which will call the next_symbol function. It takes in the input_sentence and the trained model, and returns the decoded sentence.
Instructions: Implement the greedy_decode algorithm.
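A hedged sketch, assuming the tokenize/detokenize helpers from Part 1, the next_symbol function above, and the [Article] <EOS> <pad> [Summary] <EOS> layout used during preprocessing:

```python
def greedy_decode(input_sentence, model, verbose=True):
    EOS = 1
    # Start from the tokenized article (tokenize already appends EOS), plus the <pad> separator.
    cur_output_tokens = tokenize(input_sentence) + [0]
    generated_output = []
    cur_output = None
    while cur_output != EOS:
        cur_output = next_symbol(cur_output_tokens, model)
        cur_output_tokens.append(cur_output)
        generated_output.append(cur_output)
        if verbose:
            print(detokenize(generated_output))  # prints the partial summary at each step
    return detokenize(generated_output)
```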
`# UNQ_C10`
`# Test it out on a sentence!`
It was a sunny day when I went to the market to buy some flowers. But
I only found roses, not tulips.
:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>
Expected Output:
:
: I
: I just
: I just found
: I just found ros
: I just found roses
: I just found roses,
: I just found roses, not
: I just found roses, not tu
: I just found roses, not tulips
: I just found roses, not tulips
: I just found roses, not tulips.
: I just found roses, not tulips.<EOS>
: I just found roses, not tulips.<EOS>
`# Test it out with a whole article!`
Expected Output:
Jordan
Jordan Ful
Jordan Fulcol
Jordan Fulcoly
Jordan Fulcoly,
Jordan Fulcoly, Wayne
Jordan Fulcoly, Wayne Dre
Jordan Fulcoly, Wayne Drexe
Jordan Fulcoly, Wayne Drexel
Jordan Fulcoly, Wayne Drexel,
.
.
.
Final summary:
Jordan Fulcoly, Wayne Drexel, Tyler Carroll and Connor Carroll were
suspended for one day. Four students were suspended for one day
because they allegedly did not heed to warnings that the 'Tebowing'
craze was blocking the hallway and presenting a safety hazard to
students.<EOS>
Congratulations on finishing this week's assignment! You did a lot of work and now you should have a better understanding of the decoder part of Transformers and how Transformers can be used for text summarization.
Keep it up!