Fastformer: Additive Attention Can Be All You Need (Machine Learning Research Paper Explained)

  28,063 views

Yannic Kilcher

Comments: 66
@YannicKilcher · 3 years ago
OUTLINE:
0:00 - Intro & Outline
2:15 - Fastformer description
5:20 - Baseline: Classic Attention
10:00 - Fastformer architecture
12:50 - Additive Attention
18:05 - Query-Key element-wise multiplication
21:35 - Redundant modules in Fastformer
25:00 - Problems with the architecture
27:30 - Is this even attention?
32:20 - Experimental Results
34:50 - Conclusion & Comments
@Mutual_Information · 3 years ago
Naming papers without "all you need" is all I need
@CharlesVanNoland · 3 years ago
I didn't even know I needed this until you told me you needed it. :D
@justinwhite2725 · 3 years ago
It's a meme. Even papers that don't have it get nicknamed "all you need".
@mattizzle81 · 3 years ago
Your comment was all I needed
@tchlux · 3 years ago
I think the main struggle people are encountering is that the "attention" mechanism in a transformer actually does two things. (1) It ranks and combines an arbitrary number of things by value, allowing the model to "filter out" items that are irrelevant for a certain task; and (2) it allows every element in an arbitrary-length list to be updated based on the "context" of what else is included in the list, which enables selective suppression of information. On the mathematical side, there's no reason these operations cannot be separated out.

I suspected someone would publish this paper eventually when they realized you can do the "additive ranking" process separately, and I predict that the next paper someone will publish will be *"Adderall Transformer: Improving focus with nonlinear attention"*. In that paper, the authors will show that using an MLP to do step (2) (the "contextual update") can greatly reduce the number of layers you need in a transformer to get the same performance.

I might be wrong, but it seems like people are going in circles here because they're not thinking about what vector operations are being performed at each step. If you haven't yet, take a second to try and draw every operation in a neural network in 2 or 3 dimensions so you can actually visualize what is going on. It reveals a lot of avenues for research. 🤓 (Well-spaced initialization schemes can accelerate early training; avoiding "dead neurons" during training by measuring distance between activations; a small but fundamental flaw in the Adam gradient curvature estimation; accelerated training with representative batch backpropagation; ...)
@kokakoriful · 3 years ago
Stolen idea for paper :V
@aspergale9836 · 3 years ago
Can you elaborate?

- How do you imagine an MLP doing the context-based update per value?
- "avoiding 'dead neurons' during training by measuring distance between activations" -- And then biasing them somehow?
- "a small but fundamental flaw in the Adam gradient curvature estimation" -- Can you give a hint? The only major flaw I'm aware of was in the weight decay term, which led to AdamW.
- "accelerated training with representative batch back propagation" -- Can you elaborate?

Thanks!
@tchlux · 3 years ago
@@kokakoriful to be clear, it's not stealing.. If you actually do this experiment and publish it, that's just science! And I'd thank you for the work. 😁 I can come up with ideas far faster than I can realize experimental results.
@tchlux · 3 years ago
@@aspergale9836 Yeah, let me try to explain.

> How do you imagine an MLP doing the context-based update per value?

Two major steps here. Your input tokens' vector representations should be part of the training procedure, same as for any categorical embedding. Now, let's say you've got a D-dimensional token, and you have N of them as your max input size. Instead of doing attention with purely linear operators (matrix multiplications), for every input token you apply an MLP that takes 2×D inputs (the concatenation of the two token embeddings) and produces D+1 outputs, where the first D become the new representation of that token and the +1 is the "value" of that output. This is an N^2 operation, and you apply the same MLP to every pair of input tokens. Now take the weighted sum of values × token representations as the new token at each point in the sequence. Even better, learn the "value" function with an MLP that is separate from the "contextual token update" function (the value function takes 2×D inputs and produces 1 output; the contextual update takes 2×D inputs and produces a D vector as output), which allows value and actual position in token space to be represented independently within the model.

I think the biggest problem someone would encounter here is having a hard time initializing the networks so that you can get convergence. My opinion on the best solution here is what I call "flexure activation functions", where you have a ReLU but the multiplicative factor between the slopes on the left and right side of 0 is also a learnable parameter. If you are failing to converge, you push the flexure parameters to 1 and you end up with a purely linear function (convex loss landscape ➞ guaranteed to converge).

> And then biasing them somehow?

I tend to think of this in terms of "how spread out are the directions being captured by the vectors in a given layer?" If you actually dig into it, I suspect you'll find that anywhere between 10% and 50% of the activations in any given layer of a ReLU MLP contribute almost nothing to the final loss of the model. This is what I've seen when I visualize the contribution of nodes to total loss on my own problems (via forced dropout). What can you do with these neurons, since they're not contributing to the loss? You can either delete them and save time (meh.. changing structure..) or you can force them into a new / more useful direction! I suggest the latter: pick a direction that is away from other neurons *and* in the direction of the gradient (zero out the components of the gradient in the direction toward the nearest other neurons). Also, incorporate the "flexure" method I described above to force neurons with terrible shift terms to still be able to capture information. Another tactic would be to simply reset the shift term to the median of the values produced by the neuron *before* applying the activation function.

> Can you give a hint? The only major flaw I'm aware of was in the weight decay term, which led to AdamW.

Fundamentally, Adam works by assuming that loss landscapes are shaped like a quadratic function (a bowl), but how is the second derivative (the squared term in the axis-aligned quadratic representation) estimated in Adam? How would you estimate it differently if you actually assumed the loss landscape was a pure axis-aligned quadratic?

> "... representative batch back propagation" -- Can you elaborate?
At any point in training (for most real problems), the distribution of loss values over your data set will generally be dominated by a fraction of the data. This is true for most distributions in the real world; most of the variance in any data with >3 dimensions is usually captured by the first ~10-30% of principal components (though there are exceptions). Acknowledging that this is almost always true, one would suspect that most of the "gradient" is coming from a 10-30% subset of the data. Something you could do to accelerate training with minimal loss of information is to compute the distance between data points in the last layer of the network times the difference in gradient value. There are many fast ways to get a well-spaced sample, but the idea is that you can probably throw away ~80% (or much more) of your training data at any given point in time and still get a nearly identical gradient. If you throw away 80%, then your backpropagation is immediately 5× faster than it was before. Clearly there are time costs to doing the "well spaced" / "representative" estimation, so there's research to be done here. But it's certainly a viable strategy that many applications could benefit from, and that's clear from the mathematics alone.
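A minimal sketch of the pairwise-MLP "contextual update" described above, assuming PyTorch; the module and variable names are hypothetical, and the softmax normalization of the pairwise values is an assumption that the comment does not spell out:

```python
import torch
import torch.nn as nn

class PairwiseMLPUpdate(nn.Module):
    """Each token pair (i, j) is scored and transformed by small MLPs;
    token i's new representation is the value-weighted sum over j."""
    def __init__(self, d: int, hidden: int = 128):
        super().__init__()
        # "value" function: concat of two D-dim tokens -> 1 importance score
        self.value_mlp = nn.Sequential(
            nn.Linear(2 * d, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        # "contextual update" function: concat of two tokens -> new D-dim token
        self.update_mlp = nn.Sequential(
            nn.Linear(2 * d, hidden), nn.ReLU(), nn.Linear(hidden, d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, D)
        B, N, D = x.shape
        # all (i, j) concatenations: (B, N, N, 2*D) -- an O(N^2) operation
        pairs = torch.cat([x.unsqueeze(2).expand(B, N, N, D),
                           x.unsqueeze(1).expand(B, N, N, D)], dim=-1)
        weights = self.value_mlp(pairs).softmax(dim=2)   # normalize over j
        updates = self.update_mlp(pairs)                 # (B, N, N, D)
        return (weights * updates).sum(dim=2)            # (B, N, D)
```

Note that the pair construction is still O(N^2), the same asymptotic cost as standard attention; the difference is only that the scoring and the update are done by nonlinear MLPs instead of dot products.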
@aspergale9836 · 3 years ago
@Thomas Lux This all sounds very interesting! If I may be so bold as to solicit more of your experience:

- What literature (papers, books, articles) would you suggest on these topics? Not in a general sense; rather, any _especially_ relevant paper or other resource that sparked the investigation or question in your mind, for each of the points you mentioned.

> well-spaced initialization schemes can accelerate early training

- Can you define briefly what "well-spaced" refers to? Currently popular initialization schemes already try to be "airy", whether based on a normal or a uniform distribution.

> If you haven't yet, take a second to try and draw every operation in a neural network in 2 or 3 dimensions so you can actually visualize what is going on

- Any examples that you can point to? Do you mean in the style of "The Illustrated Transformer", or in the style of "dimensional analysis" in physics? Or something else? I know the sentence is clear and makes _sense_, in the way that I do have an idea of what I'd try to do. But the devil is in the details, especially in investigative processes that can take quite a bit of cumulative wisdom to perfect.

- "Flexure activation functions" doesn't seem to be an established term (as you did say). I understood it to be basically PReLU (Parametric ReLU), but with the right/positive side also parametric. That's an interesting idea. Certainly, the "next" linear layer can be updated so that all weights reading a particular input neuron are multiplied by a factor as needed, but that's slower in terms of realization (gradient + update steps until the factor is applied) than changing just the PReLU-2's weight, and so is possibly less stable.

> but how is the second derivative (squared term in the axis-aligned quadratic representation) estimated in Adam? How would you estimate it differently if you actually assumed the loss landscape was a pure axis-aligned quadratic?

- I'm afraid you lost me :'D What's the difference between an "axis-aligned quadratic representation" (which I'm _guessing_ refers to having no bias, and so "touches" each of the base axes) and a "**pure** axis-aligned quadratic"? What's "pure" here?

> or you can force them into a new / more useful direction! I suggest the latter, pick a direction that is away from other neurons and in the direction of the gradient (zero out the components of the gradient in the direction towards nearest other neurons).

- The suggestions here also sound very interesting (e.g. the post-training step suggested at the end to revive dead neurons by debiasing them). But I'm confused about the in-training suggestion (as in the quote). E.g. what does the "direction of a neuron" refer to? Naively, I'd think it can have only two directions: positive and negative. But maybe you have something more elaborate in mind? Like, I don't know, clustering the _magnitudes_ of each neuron, then trying to make the other neurons find an empty "spot" in their expected magnitudes? I'm probably talking nonsense...

> There are many fast ways to get a well-spaced sample, but the idea is that you can probably throw away ~80% (or much more) of your training data at any given point in time and still get a nearly identical gradient.

- Something like clustering the data, then sampling?

Thanks a lot!
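For concreteness, a minimal sketch of the "flexure" activation under the interpretation above (a ReLU-like function with learnable slopes on both sides of zero), assuming PyTorch; this is one reading of the comment, not an established module, and the per-feature parameterization is an assumption:

```python
import torch
import torch.nn as nn

class Flexure(nn.Module):
    """ReLU-like activation with a learnable slope on each side of zero,
    applied per feature."""
    def __init__(self, num_features: int):
        super().__init__()
        self.pos_slope = nn.Parameter(torch.ones(num_features))   # slope for x >= 0
        self.neg_slope = nn.Parameter(torch.zeros(num_features))  # slope for x < 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (..., num_features)
        return torch.where(x >= 0, self.pos_slope * x, self.neg_slope * x)
```

When both slope parameters are driven to 1, the activation becomes the identity, recovering the purely linear (convex) regime mentioned in the thread.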
@mgostIH · 3 years ago
In my opinion, a big takeaway from this paper (if we had more experiments and ablations, as you mentioned) could be that it shows how, in practice, machine learning models need very little in terms of abstraction and ideas; they just require information to be carried around in a learnable way. In this sense I think it's closer to MLP-Mixer than to attention, working with the core idea that you need learned transformations per vector and learned pooling for each layer; here the attention is only needed for the latter and may be replaceable as well.
@linminhtoo · 3 years ago
Yannic, I have noticed that your explanations have become a lot more intuitive & easy to understand as time goes by (they were already easy to understand from the beginning). Really great to see that!
@DamianReloaded · 3 years ago
It just occurred to me that it would be interesting to watch follow-up videos on papers, covering other researchers' implementations of that paper: "Papers implemented" or "Papers: does it actually work?"
@itzikgutzcha4779 · 3 years ago
Thank you, I find it difficult reading and understanding these papers and it helps a lot to watch your videos
@Xaelum · 3 years ago
It was also mentioned in the paper that they pre-initialized the word embeddings with GloVe vectors. Why GloVe? Did it converge faster? Did it yield better performance than learning the embeddings from scratch? The results of the paper are very good performance-wise, but there are SO many unexplained decisions that it feels like they rushed to get this public. I hope the follow-up papers (and hopefully code) help us learn more about these things.
@luke.perkin.online · 3 years ago
You were a little harsh. Yes, they chose a dumb name. Yes, it's not really Q/K/V, or attention. But being on par with the state of the art without those things, and with roughly an order of magnitude less training and inference time, is good. It makes intuitive sense to project each element to a different high-dimensional space and then learn global patterns that span all spaces, a bit like conv channels. Hopefully a follow-up paper with ablations and multiple columns will come soon!
@mgostIH · 3 years ago
That query + values connection done at the end might be similar to the residual connections in transformers, but I agree that it might be something added after the network failed to converge (and may as well work in other places).
@harry1010 · 3 years ago
What’s faster? Fastformer, or Yannic breaking down novel algorithms?!
@ophir1080 · 3 years ago
Hi :) A small clarification would be well appreciated: In practice, the Query and Key vectors are derived as a result of linear projections (by W_q and W_k respectively). Therefore, I don’t quite understand your interpretation (their “jobs” in the mechanism). Is this your intuition (which btw I extremely adore!)? Thanks!
@L9X · 3 years ago
So the linear projection is unlikely to be orthogonal, i.e. the 'intensity' of the information in a vector is changed after projection, so when passing through the softmax function followed by multiplication and addition, only the most 'intense' information is left in the resulting query/key vector. So the 'job' of the projection is to decide which *type* of information from all the input tokens is relevant, and the resulting vector from the additive attention encapsulates how intense that type of information is across all input values. If you look into squeeze-and-excitation networks, a similar concept is used for CNNs, where the squeeze-and-excite operator produces an attention map per channel, not per pixel. To me Fastformer seems like squeeze-and-excite on steroids, and with enough layers, global and local context can be determined even if we are 'grouping' all elements into a single key/query vector through summation.
@felipemello1151 · 3 years ago
I believe that they add the query vector to the end result because query and value are apparently the same (they share weights). So it seems to work as a residual connection.
@herp_derpingson · 3 years ago
19:30 P vector.

20:33 In a normal transformer, each word paid attention to each word. Now each word pays attention to the entire sentence, which is just the (attention-weighted) sum of all the queries of the words.

25:50 Once in a while you do see papers which apply a few engineering tricks to make things work, but this q addition was the craziest of all.

26:00 The Q K V column thing is a good catch. I didn't notice it. Just 2 columns are sufficient for intercommunication.

There are so many papers trying to make the transformer sub-quadratic. However, as a human, when I read a book I don't store the entire context in my memory. Instead I remember which page I need to jump back to to get a refresher. I think in order to make transformers fast, we need to turn the context into pointers instead of removing the quadratic attention.
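To make the mechanism discussed in the video and in the comments above concrete, here is a minimal single-head sketch of Fastformer-style additive attention, reconstructed from the paper's description and assuming PyTorch. Multi-head splitting, the 1/sqrt(d) scaling, and the paper's query/value weight sharing are omitted, and all names are chosen for illustration:

```python
import torch
import torch.nn as nn

class FastformerAttention(nn.Module):
    """Single-head sketch: a global query is pooled with additive attention,
    multiplied element-wise into the keys, a global key is pooled the same
    way, multiplied element-wise into the values, and the query is added
    back at the end."""
    def __init__(self, d: int):
        super().__init__()
        self.to_q = nn.Linear(d, d)
        self.to_k = nn.Linear(d, d)
        self.to_v = nn.Linear(d, d)
        self.score_q = nn.Linear(d, 1)   # additive-attention scores for queries
        self.score_k = nn.Linear(d, 1)   # additive-attention scores for q*k
        self.out = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:    # x: (B, N, D)
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        alpha = self.score_q(q).softmax(dim=1)              # (B, N, 1)
        global_q = (alpha * q).sum(dim=1, keepdim=True)     # (B, 1, D)
        p = global_q * k                                    # element-wise query-key interaction
        beta = self.score_k(p).softmax(dim=1)               # (B, N, 1)
        global_k = (beta * p).sum(dim=1, keepdim=True)      # (B, 1, D)
        u = global_k * v                                    # element-wise key-value interaction
        return self.out(u) + q                              # the query skip connection
```

Both pooling steps are linear in N, which is where the claimed speedup over quadratic self-attention comes from; the final `+ q` is the query skip connection debated in this thread.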
@ЗакировМарат-в5щ · 3 years ago
I think the reason behind q, k, and v (in this Fastformer model) is that you can still have three conceptually separate streams of information. If you think about it, you could also remove the keys from the transformer; why not, actually? The only reason not to is the habit of thinking in terms of three separate streams of information.
@aBigBadWolf · 3 years ago
Yannic, Schmidhuber has a follow-up paper where they go "beyond linear transformers". Looks like a generalisation that would suit a video?
@drhilm · 3 years ago
Is this basically a complex MLP with a residual connection?
@fiNitEarth · 3 years ago
19:34 amazing
@jeremykothe2847 · 3 years ago
100k gogogo!
@L9X · 3 years ago
In regards to 24:50 where you say you don't know why that is done, I think it's related to how in the AIAYN transformer decoder the query input is always added to the MHA output, and the query input itself is a result of causal MHSA.
@brokecoder · 3 years ago
I think the Perceiver architecture from DeepMind was a neater way of getting around the quadratic complexity of attention than the Fastformer introduced in this paper.
@muhammadbilal902 · 3 years ago
Yannic, the way you pronounce those Chinese names :-D
@김건일-g4g · 3 years ago
Technically, I understand the design up until the end, where we add the query vectors to the sum of value vectors...
@MrMIB983 · 3 years ago
Is there a repo?
@NGNBoone · 3 years ago
If they had called it a feed-forward network with context-aware modulations or some such thing, would you have done a video on it? The "key, query, value" names in the original Transformer paper were somewhat arbitrary as well... You make a good case for why it's not properly labeled a Transformer, but... idk. LucidRains has a GitHub repo implementing the network. What DOES the network do if you remove that skip connection?
@loicgrossetete9570 · 3 years ago
I think you are onto something with the column repetition! It might help to do multiple aggregations; it would be like modifying the words' interpretation in a sentence based on the global context, then updating this global context to better understand the words, which leads to a slightly different context, and so on.
@deoabhijit5935 · 3 years ago
Whoa, so much going on with transformers!
@stecarretti · 3 years ago
Can someone please tell me why in the paper they write that the global query vector has shape (d, d), and not just d? They also write that the Q, K and V matrices have shape (d, d) instead of (N, d).
@stonecastle858 · 3 years ago
Hi Yannic, in your view, which of the proposed solutions (so far) to address the self-attention computational costs is most successful?
@srh80 · 3 years ago
Wild guess for the query skip connection: the gradient flow probably has bottlenecks with the two levels of aggregation.
@hecao634 · 3 years ago
27:01 yeah it even sounds like an RNN?
@jean-baptistedelabroise5391 · 3 years ago
I don't know why we don't just do windowed attention, where each token attends to the previous and next n tokens, and maybe only the [CLS] token (in the BERT case) attends to the whole thing.
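A minimal sketch of such a windowed attention mask with a global [CLS] token, assuming PyTorch and assuming (for illustration) that [CLS] sits at position 0:

```python
import torch

def windowed_attention_mask(n_tokens: int, window: int) -> torch.Tensor:
    """Boolean mask (True = may attend): each token sees tokens within
    `window` positions of itself; token 0 (assumed to be [CLS]) sees,
    and is seen by, every token."""
    idx = torch.arange(n_tokens)
    mask = (idx[:, None] - idx[None, :]).abs() <= window  # banded local window
    mask[0, :] = True   # [CLS] attends to everything
    mask[:, 0] = True   # everything attends to [CLS]
    return mask

# Convert to the additive -inf form used by scaled dot-product attention:
mask = windowed_attention_mask(8, window=2)
additive_mask = torch.zeros(8, 8).masked_fill(~mask, float("-inf"))
```

This is essentially the local-window-plus-global-token pattern used by models such as Longformer.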
@srh80 · 3 years ago
Hey Yannic, you should stitch together multiple similar videos to create compilation videos, maybe about 3 to 5 hours long. Ideal for some viewers to fall asleep to, and more views for your channel.
@srh80 · 3 years ago
Just curious, has anyone tried setting the number of queries and keys to log(num tokens) and keeping everything else like the original transformer?
@bionhoward3159 · 3 years ago
I'd love to see more videos on active inference, causality, analogy, abstraction, GNNs, RL, autoML, genetic programming, and other more advanced ML topics. How do you model conditional probability with ML? How do you make an eligibility trace for DNNs?
@Phenix66 · 3 years ago
They just couldn't resist the title, could they?^^
@XOPOIIIO · 3 years ago
I think AGI is almost here; you just have to aggregate what is already known and make a few simple insights. It's just that nobody is able to read all the papers and find out everything that is already known.
@조성민-y9n · 3 years ago
that's not true
@XOPOIIIO · 3 years ago
@@조성민-y9n How do you know that?
@ThiagoSTeixeir4 · 3 years ago
"Memes Is All You Need."
@swordwaker7749 · 3 years ago
The Nyströmformer already works pretty well...
@Notshife · 3 years ago
If the little learned feature detector was done 6 times (still a reduction from N), then it could learn who, what, where, when, why, and how? :P
@srh80 · 3 years ago
"All you need" is all you need.
@ArjunKumar123111 · 3 years ago
I'm gonna put all my paper titles as X can be all you need lmao
@djfl58mdlwqlf · 3 years ago
this is dang fast!! (upload I mean)
@wolfisraging · 3 years ago
In the title for the papers of these days, "all you need" is all you need.
@konghong3885 · 3 years ago
When the paper already did the clickbait for you
@guidoansem · 3 years ago
algo
@minos99 · 3 years ago
The Transformer is not a sacred text. Feel free to butcher it, add to it, and subtract from it. I welcome these self-attention heretics, but at the same time I demand ablations and comparisons with peers based on perplexity.
@Lee-vs5ez · 3 years ago
.....
@paxdriver · 3 years ago
Lol "transformers all you need" is soooo overused
@jadtawil6143 · 3 years ago
"there is no qualitative difference between the keys and queries".. hence it is a trash paper.. !