Written by Ethan Smith

Intro and Previous Methods

The motivation for this work came from toying around with ToMeSD around April 2023 and finding that their method did not yield the speedups I had been expecting, at least at token-reduction ratios that did not decimate quality: https://github.com/dbolya/tomesd/issues/19#issuecomment-1507593483.

For some background, their method is based on the original Token Merging: Your ViT But Faster, which increases the throughput of an attention network by merging the tokens that have the highest cosine similarity to each other. By merge, I mean a conventional average.

It is pretty well established that global attention can often be overkill. You can have a lot of tokens that contribute very little to the output, or, in our case, a number of tokens that are nearly identical to each other and do not really add any information.

In a toy case, suppose we perform attention over two identical tokens; call their shared value X. Each gets a softmax weight of 0.5, so the output is 0.5X + 0.5X = X. Three identical tokens each get a weight of 1/3, and the output is still X. However many duplicates there are, the output is 1X.
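To make that concrete, here is a quick NumPy sketch (my own toy code, not from either paper) showing that scaled dot-product attention over identical tokens returns X no matter how many copies are present:

```python
import numpy as np

def attention(q, K, V):
    # Standard scaled dot-product attention for a single query q: (d,), K, V: (n, d)
    scores = K @ q / np.sqrt(q.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V

d = 4
rng = np.random.default_rng(0)
q = rng.normal(size=d)
x = rng.normal(size=d)  # a single token value X

out_2 = attention(q, np.stack([x, x]), np.stack([x, x]))        # 0.5X + 0.5X
out_3 = attention(q, np.stack([x, x, x]), np.stack([x, x, x]))  # three weights of 1/3

print(np.allclose(out_2, x), np.allclose(out_3, x))  # True True: output is always X
```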

<aside> 💡 (Note: in realistic cases, where we have many other tokens with different values, duplicate tokens can have other effects, like down-weighting the others, so it is not entirely equivalent. Potentially this is a way to improve ToMe or our method, by accounting for this change!

Edit: this may inadvertently be part of the success of our method: because the subsampling we do is uniform, all regions of tokens are reduced evenly.)

Here is a toy example of what I mean: simply duplicating the same items flattens the distribution, but the total mass assigned to each value remains the same.

Merging by similarity may favor some regions more than others, inadvertently increasing or decreasing the influence of certain tokens. This may be why ToMe's outputs drift further from the original model's. (A small numerical sketch of this appears right after this aside.)

[Figure: toy example showing that duplicating items flattens the softmax distribution while the total mass per value stays the same]

</aside>
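Here is the numerical sketch mentioned above (again my own toy code): duplicating every token uniformly halves each individual softmax weight but leaves the total mass per distinct value unchanged, whereas duplicating only one token pulls mass toward it and down-weights the rest.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 0.5])           # three distinct tokens A, B, C
base = softmax(scores)

uniform_dup = softmax(np.repeat(scores, 2))              # every token duplicated once
uneven_dup = softmax(np.concatenate([scores, [2.0]]))    # only A duplicated

print(base)                                   # per-value mass for A, B, C
print(uniform_dup.reshape(3, 2).sum(axis=1))  # same per-value mass as `base`
print(uneven_dup[[0, 3]].sum(), uneven_dup[1], uneven_dup[2])  # A gains, B and C shrink
```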

The idea is that we can simply average tokens that have similar values and then treat the result as a single token, thus decreasing our sequence length.
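A minimal sketch of what that could look like (a naive greedy pairing for illustration, not ToMe's actual bipartite soft matching):

```python
import numpy as np

def merge_most_similar(tokens, r):
    """Greedily average the most cosine-similar pair of tokens, r times.

    tokens: (n, d) array -> returns (n - r, d).
    """
    tokens = tokens.copy()
    for _ in range(r):
        normed = tokens / np.linalg.norm(tokens, axis=1, keepdims=True)
        sim = normed @ normed.T
        np.fill_diagonal(sim, -np.inf)           # ignore self-similarity
        i, j = np.unravel_index(np.argmax(sim), sim.shape)
        merged = (tokens[i] + tokens[j]) / 2     # "merge" = plain average
        keep = np.delete(np.arange(len(tokens)), [i, j])
        tokens = np.vstack([tokens[keep], merged])
    return tokens

x = np.random.default_rng(0).normal(size=(16, 8))
print(merge_most_similar(x, r=4).shape)          # (12, 8): 4 fewer tokens
```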

If we can cut the number of tokens even just in half, then because of attention's O(n²) scaling, the operation becomes theoretically 4x as efficient.
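Spelling out where that factor of four comes from:

```latex
\text{cost}(n) \propto n^{2}
\quad\Longrightarrow\quad
\text{cost}\!\left(\tfrac{n}{2}\right) \propto \left(\tfrac{n}{2}\right)^{2} = \tfrac{n^{2}}{4}
```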