

Poster

MatFormer: Nested Transformer for Elastic Inference

Fnu Devvrit · Sneha Kudugunta · Aditya Kusupati · Tim Dettmers · Kaifeng Chen · Inderjit Dhillon · Yulia Tsvetkov · Hannaneh Hajishirzi · Sham Kakade · Ali Farhadi · Prateek Jain

East Exhibit Hall A-C #2507
Wed 11 Dec 4:30 p.m. PST — 7:30 p.m. PST

Abstract:

Foundation models are applied in a broad spectrum of settings with different inference constraints, from massive multi-accelerator clusters to resource-constrained standalone mobile devices. However, the substantial cost of training these models often limits the number of unique model sizes that can be offered. Consequently, practitioners are forced to select a model that may not be optimally aligned with their specific latency and cost requirements. We present MatFormer, a novel Transformer architecture designed to provide elastic inference across diverse deployment constraints. MatFormer achieves this by incorporating a nested Feed Forward Network (FFN) block structure within a standard Transformer model. During training, we optimize the parameters of multiple nested FFN blocks of varying sizes, enabling the extraction of hundreds of accurate smaller models without incurring additional computational costs. We empirically validate the efficacy of MatFormer across different model classes (decoders and encoders) and modalities (language and vision), demonstrating its potential for real-world deployment. We show that an 850M decoder-only MatFormer language model (MatLM) allows us to extract multiple smaller models spanning from 582M to 850M parameters, each exhibiting better validation loss and one-shot downstream evaluation results than its independently trained counterpart. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we show that speculative decoding with the accurate and consistent submodels extracted from MatFormer can lead to a significant reduction in inference latency.
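To make the nested FFN idea concrete, here is a minimal pure-Python sketch, assuming a bias-free two-layer FFN with ReLU. It is an illustration, not the authors' implementation. A sub-FFN of size m uses only the first m hidden units, so every smaller block is a prefix of the larger one and all sizes share weights. The names `NestedFFN`, `joint_loss`, and `mix_and_match` are hypothetical. The granularities (d_ff/8, d_ff/4, d_ff/2, d_ff) are likewise an assumption. So is the training objective, a plain average of per-size squared-error losses. Choosing a size independently for each layer is one way a single trained model can yield many extractable submodels.

```python
import itertools
import random


def relu(v):
    return v if v > 0.0 else 0.0


class NestedFFN:
    """Toy nested FFN: d_model -> d_ff -> d_model, no biases (assumed).

    The sub-FFN of size m uses only hidden units 0..m-1, so every smaller
    FFN is a prefix of the larger one and all sizes share the same weights.
    """

    def __init__(self, d_model, d_ff, seed=0):
        rng = random.Random(seed)
        self.d_model, self.d_ff = d_model, d_ff
        # Row j of w_in / w_out holds hidden unit j's input / output weights.
        self.w_in = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]
        self.w_out = [[rng.gauss(0.0, 0.1) for _ in range(d_model)] for _ in range(d_ff)]

    def forward(self, x, m):
        """Run the nested sub-FFN made of the first m hidden units."""
        if not 1 <= m <= self.d_ff:
            raise ValueError("m must lie in [1, d_ff]")
        y = [0.0] * self.d_model
        for j in range(m):
            h = relu(sum(w * xi for w, xi in zip(self.w_in[j], x)))
            for k in range(self.d_model):
                y[k] += h * self.w_out[j][k]
        return y

    def num_params(self, m):
        """Weights used by the size-m sub-FFN (input plus output projections)."""
        return 2 * m * self.d_model


def joint_loss(ffn, x, target, granularities):
    """Assumed joint objective: squared error averaged over all nested sizes."""
    total = 0.0
    for m in granularities:
        y = ffn.forward(x, m)
        total += sum((yi - ti) ** 2 for yi, ti in zip(y, target))
    return total / len(granularities)


def mix_and_match(granularities, num_layers):
    """Enumerate per-layer size choices; each tuple is one extractable model."""
    return list(itertools.product(granularities, repeat=num_layers))


if __name__ == "__main__":
    d_model, d_ff = 8, 32
    grans = [d_ff // 8, d_ff // 4, d_ff // 2, d_ff]  # hypothetical: 4, 8, 16, 32
    ffn = NestedFFN(d_model, d_ff)
    x = [0.5] * d_model
    print([ffn.num_params(m) for m in grans])  # [64, 128, 256, 512]
    print(joint_loss(ffn, x, [0.0] * d_model, grans))
    print(len(mix_and_match(grans, num_layers=4)))  # 256
```

Because the sizes are prefixes of one weight matrix, a gradient step on `joint_loss` updates the shared leading units for every granularity at once. With four sizes and only four layers, per-layer choices already give 4^4 = 256 distinct submodels, which illustrates how the count of extractable models grows quickly with depth.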
