Теперь Кью работает в режиме чтения

Мы сохранили весь контент, но добавить что-то новое уже нельзя
CTO Intento, ранее руководитель разработки Яндекс-...  · 24 дек 2021

Pay Attention to MLPs

Hanxiao Liu, Zihang Dai, David R. So, Quoc V. Le
Код (неофициальный): https://github.com/lucidrains/g-mlp-pytorch
В области в последние полгода происходит большое бурление. По инерции всё ещё кажется, что фронтир многих исследований в computer vision и nlp уже (или соответственно всё ещё) находится в трансформерах. И там действительно много всего происходит, так что следить за всем потоком публикаций уже просто нереально и надо держать небольшой НИИ (а лучше автоматического агента), чтобы разбирать только свежие работы. Но это всё же не совсем уже фронтир. А вот старые добрые многослойные персептроны таки фронтир!
Буквально через пару недель после MLP-Mixer (https://t.me/gonzo_ML/776) и кучки других практически одновременно вышедших похожих работ появилась новая интересная работа на аналогичную тему.
Авторы снова подходят со стороны замены трансформера на варианты полносвязных сетей, экспериментируют с разными вариантами MLP и выбирают одну наиболее эффективную конфигурацию. Полученная модель называется gMLP и её проверяют на задачах классификации картинок ImageNet и на обучении BERT’а.
Собственно авторы сделали следующее. Как и трансформер, сеть состоит из набора одинаковых блоков, на вход которых прилетают эмбеддинги из предыдущего слоя (или входные), а на выходе вылетают обработанные эмбеддинги той же размерности. Размерность входа n*d, где n — число токенов (или длина последовательности), а d — размер эмбеддинга (ну плюс ещё подразумеваемое измерение для батча).
Модель по входам и выходам совместима с BERT/ViT.
Внутри нет никаких блоков self-attention, а только нормализация, channel projection, активация (GELU), блок spatial projection и снова channel projection на выходе
Channel projection — это обычные линейные проекции как в FFN-слоях трансформеров, в конфигурации аналогичной BERT_base это 768x3072 и 3072x768 (на входе и выходе соответственно, причём на выходе скорее первое измерение в два раза меньше указанного, потому что там хитрый сплит по каналам и гейтинг, о которых ниже).
Spatial projection — это самое мясо gMLP, слой, который устраивает взаимодействие между различными токенами (а не каналами). Если этот слой заменить на identity, то из блока получится обычная FFN, где каждый токен обрабатывается независимо и друг с другом они не взаимодействуют.
Целью работы было найти такой вариант этого блока, который позволит заложить в систему сложные взаимодействия между токенами. В самом эффективном из вариантов, который далее в работе и используется, этот блок сначала делает разделение каналов (split) на два потока одинакового размера, которые в конце поэлементно перемножатся. Первый поток без изменений отправляется к выходу, а второй сначала прогоняется через нормализацию, а затем линейную проекцию размерности n*n (где как раз для каждого элемента можно определить влияющие на него другие элементы), и далее эти два потока поэлементно (по измерению каналов) перемножаются, то есть имеем некий gating, когда один вход выбирает, что возьмём из другого. Чтобы при обучении это не взорвалось, данную функцию инициализируют так, чтобы на старте это была практически identity трансформация. Хороший разбор кода есть тут https://nn.labml.ai/transformers/gmlp/index.html.
Данный блок называется Spatial Gating Unit (SGU). Это напоминает гейтинг из LSTM/GRU или скорее даже гейтинг из Highway Networks (https://arxiv.org/abs/1505.00387) или GLU (https://arxiv.org/abs/1612.08083), только в отличие от последних он вычисляется не по проекции скрытого измерения (которое здесь называется измерением каналов), а по проекции пространственного (которое отвечает за cross-token interactions). По мне так ещё одно существенное отличие от гейтинга в том, что обычно выход гейта нормализован (например, через сигмоиду) так, чтобы он был в диапазоне [0,1], здесь же, хоть на входе нормализация и есть, далее работает фактически полноценная multiplicative interaction, где в принципе итоговый вес может после линейного преобразования быть любым, и ещё вопрос, кто там в итоге кого гейтит.
Важно по сути, что модель в отличие от трансформеров вообще не использует позиционные эмбеддинги, и кроме того функция смешивания токенов не зависит от входных репрезентаций (как это есть в трансформерах, где она динамически генерируется из входных данных через механизм внимания).
Полученную модель проверяют на картинках и текстах. На картинках решают задачу классификации ImageNet, аналогичным ViT образом, когда картинка конвертируется в набор патчей 16x16 пикселей. Авторы обнаружили, что gMLP любит переобучаться, поэтому в него добавили разных регуляризация по аналогии с более производительным DeiT (улучшенный ViT, если кто не следил, https://arxiv.org/abs/2012.12877). Модели подобрали примерно соответствующие по числу параметров ViT/DeiT.
Результаты хороши. SoTA для картинок ожидаемо не побита (там, конечно, есть свои зверски затюненные герои из свёрточного мира), но gMLP ощутимо бьёт ViT и примерно соответствует DeiT. Ну то есть модель без self-attention прекрасно справляется с изображениями. А также gMLP бьёт недавно описанные MLP-Mixer/ResMLP (https://t.me/gonzo_ML/776).
Предыдущие MLP пытались заменить ViT, но не шли в NLP, а авторы текущей работы идут. На задаче обучения BERT-подобной модели попробовали разные варианты блоков MLP и собственно пришли к варианту multiplicative + split. Модель получилась сравнимая с BERT_base, и намного лучше его же с отключенным механизмом внимания (что, видимо, заслуга SGU). Также gMLP ощутимо лучше MLP-Mixer.
Для BERT-подобной модели сравниваются по трём метрикам: perplexity самой предобученной модели, а также результаты файнтюнинга модели на две задачи из GLUE — SST-2 и MNLI.
На предобучении gMLP сначала отстаёт от трансформера по perplexity, но потом при увеличении размера модели начинает обгонять. Степенной закон в целом довольно близкий к трансформеру.
Веса spatial projections повизуализировали и нашли разные красивые выученные фильтры. Также для задачи MLM (masked language model) обнаружили, что gMLP выучивает Тёплицевы матрицы в качестве весов линейной проекции в SGU, что является выученной из данных инвариантностью к перемещению (неважно, в какой позиции стоял токен, который надо заполнить). Это позволяет также более экономно хранить эти матрицы в модели.
С файнтюнингом интересно, на SST-2 gMLP всегда обгоняет трансформер, а на MNLI всегда отстаёт. Авторы делают вывод, что для задач MNLI у трансформера более подходящий inductive bias, и так как в этих задачах необходимо работать с двумя предложениями (вместо одного как в SST-2), то видимо gMLP не хватает способности глядеть на соседние предложения.
Поэтому сделали ещё один вариант модели под названием aMLP (“a” потому что attention), где в блок SGU добавили одну единственную голову self-attention размера 64. Её выход плюсуется к выходу spatial projection до перемножения с другим стримом внутри блока.
aMLP берёт лучшее из двух миров и бьёт трансформер везде. Получается, что inductive bias механизма внимания и spatial gating несколько разные и дополняют друг друга.
Короче, получается, что в NLP можно в общем-то и без трансформеров, старыми добрыми MLP. Намёки на это были и за год до описанных работ (например, https://arxiv.org/abs/2005.13895). Были, правда, движения и в обратную сторону, чтобы наоборот все FFN заменить вниманием (https://arxiv.org/abs/1907.01470).
Официального кода нет, но есть неофициальный. Из интересного, есть реализация GPT на базе gMLP (PyTorch: https://github.com/lucidrains/g-mlp-gpt, JAX: https://github.com/lucidrains/mlp-gpt-jax).
Источник: https://t.me/gonzo_ML/787
Машинное обучение+2