Description
Describe the feature and the current behavior/state.
FAIR have a cool paper introducing Product-Key Memory layers - layers that can add a huge number of parameters (100M-1B) to a network with minimal compute overhead.
Unfortunately, implementing them efficiently depends on the EmbeddingBag layer from PyTorch. This layer is essentially a gather op followed by a weighted sum across the final dimension of the gather indices.
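To make the semantics concrete, here is a minimal NumPy sketch of what an EmbeddingBag-style op in "sum" mode computes (function name and shapes are illustrative, not from any library):

```python
import numpy as np

def embedding_bag_sum(table, indices, weights):
    """Reference semantics of an EmbeddingBag-style op in 'sum' mode:
    gather rows of `table`, then take a weighted sum over each bag."""
    gathered = table[indices]                            # (batch, bag_size, dim)
    return (gathered * weights[..., None]).sum(axis=1)   # (batch, dim)

rng = np.random.default_rng(0)
table = rng.standard_normal((1000, 16))       # embedding table
indices = rng.integers(0, 1000, size=(4, 32)) # 4 bags of 32 indices each
weights = rng.random((4, 32))                 # per-index weights
out = embedding_bag_sum(table, indices, weights)
assert out.shape == (4, 16)
```

The point of the fused op is to produce `out` without ever holding the `(batch, bag_size, dim)` intermediate in memory.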
It is trivial to implement this op as a composition of two or three ops in TensorFlow, but doing so requires materializing the output of the gather, which in the case of Product-Key Memory layers is enormous and usually blows out my GPU RAM. By combining these ops into a single efficient call, EmbeddingBag avoids ever materializing the extremely large pre-sum gather output. There's no efficient way to do the same in TensorFlow without a custom op.
I've already gotten a CUDA and (single-threaded) CPU implementation of EmbeddingBag working locally using the custom-op repo and associated docker image. I've verified correctness by comparing outputs and gradients to those from the manual composition of ops, and speed and memory usage are vastly improved. I could also contribute a TF implementation of the Product-Key Memory layer itself if desired.
Relevant information
- Are you willing to contribute it (yes/no): yes
- Are you willing to maintain it going forward? (yes/no): yes
- Is there a relevant academic paper? (if so, where): https://arxiv.org/abs/1907.05242
- Is there already an implementation in another framework? (if so, where): Yes, EmbeddingBag is already a PyTorch layer
- Was it part of tf.contrib? (if so, where):
Which API type would this fall under (layer, metric, optimizer, etc.)
Layer
Who will benefit from this feature?
People who want to squeeze loads of parameters into their model while maintaining fast throughput and aren't worried about overfitting. The paper used it for big autoregressive NLP Transformers, but I suspect you could deploy it in a lot of other places too.
Any other info.
I have only implemented the portions of EmbeddingBag necessary for Product-Key Memory layers.