Теперь Кью работает в режиме чтения

Мы сохранили весь контент, но добавить что-то новое уже нельзя
CTO Intento, ранее руководитель разработки Яндекс-...  · 25 янв 2022

Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets

Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, Vedant Misra
Недавняя (на самом деле уже не совсем, она была на ICLR 2021 на 1st Mathematical Reasoning in General Artificial Intelligence Workshop) прикольная работа от OpenAI из серии про природу вещей, вернее про природу обучения и генерализации в нейросетях.
Авторы продемонстрировали прикольный феномен, который они назвали “grokking” (“врубание” или “схватывание”? Кстати, я не знал, что это слово придумал Хайнлайн, https://en.wikipedia.org/wiki/Grok), когда нейросеть резко переходит от качества случайного угадывания к идеальному качеству, причём случается это сильно после точки оверфиттинга (когда на обучающем датасете уже идеально, а на тестовом всё фигово).
Классный пример работы, когда фундаментально интересный результат можно обнаружить на весьма скромных ресурсах без всяких там кластеров и суперкомпьютеров, хоть на домашнем GPU.
Что делали?
Авторы взяли алгоритмически сгенерированный датасет с бинарными операциями вида a∘b=c, где все эти “a”, “b”, “c”, “=” и “∘” — это просто токены, а сами операции — это различные варианты бинарных операций в модулярной арифметике (по модулю 97, но не уверен, только или ещё и по другим модулям тоже) типа сложения/вычитания, умножения/деления, суммы квадратов и т.п.
Каждое из 97 чисел кодируется отдельным дискретным символом, ни про какую внутреннюю структуру элементов сеть ничего не знает и должна выучить их свойства из данных.
На датасете таких вот примеров обучается маленький декодер трансформера с causal attention masking, который должен предсказывать результат операции, то есть по факту он должен заполнить пропуски в таблице бинарной операции. Обучаются на разных пропорциях обучающей выборки относительно всего датасета. Итоговая точность предсказания считается только по части, соответствующей правой части уравнения.
В трансформере всего 2 слоя, 4 головы внимания, размер эмбеддинга 128, это даёт около 400K параметров без учёта эмбеддингов. По дефолту обучают AdamW.
Собственно главный феномен в том, что точность на обучающей выборке довольно быстро добирается до близкой к идеальной (например, за 1K шагов оптимизации), но на тестовой выборке генерализации не видно до 100K шагов, а в районе 1М точность на тестовом датасете наконец добирается до уровня точности на обучающем.
Этот эффект проверили на разных моделях, оптимизаторах и размерах датасета и более-менее везде на маленьких датасетах он воспроизводится. На более крупных датасетах кривые качества на трейне и тесте следуют друг другу более тесно, а с уменьшением датасета время оптимизации до достижения высокого уровня качества быстро растёт. Для некоторых операций (например, x^3 + xy^2 + y (mod 97)) генерализации в пределах заданного вычислительного бюджета не случается на любом размере датасета вплоть до 95%.
В абляциях пробовали разные интервенции: Adam в различных вариантах — обычный и с полными батчами (точный градиент на всём датасете), с полным батчем и градиентным шумом, с шумом для весов модели, с низким или высоким learning rate, с residual dropout, а также AdamW с двумя разными вариантами weight decay — к началу координат или к инициализации.
Большой эффект на data efficiency оказывает weight decay (L2-лосс на веса модели), особенно wd к началу координат (это и стало потом дефолтом во всех экспериментах), сокращая необходимое число сэмплов более чем в два раза по сравнению с другими изменениями.
Я не до конца понял, опускается ли лосс на трейне прямо до нуля, или он какой-то ненулевой и этого хватает, чтобы понемногу куда-то выбраться по ландшафту приспособленности и перестроить веса сети. Или он прям нулевой и они на одном weight decay добираются до нужной точки. Поскольку это работает (хоть и сильно хуже) и без weight decay, и даже на full-batch Adam (где градиенты типа точные), то я склоняюсь к первому варианту.
Это как-то перекликается с историей про двойной спуск (double descent, про него, кстати, у Коли Михайловского недавно был хороший семинар с первооткрывателем, Михаилом Белкиным, https://ntr.ai/webinar/nauchno-tehnicheskij-vebinar-chemu-uchit-glubokoe-obuchenie/) и особенно с историей про двойной спуск не относительно сложности модели (про что оригинальная работа, https://www.pnas.org/content/116/32/15849.short), а про двойной спуск относительно количества итераций обучения (тоже, кстати, работа OpenAI, https://arxiv.org/abs/1912.02292).
Для тех, кто не следил, в двух словах. Двойной спуск по сложности модели даёт неожиданную картинку поведения validation loss в зависимости от сложности модели — увеличивая сложность модели мы не уходим навсегда в область переобучения (как учит нас классическая теория), а быстро проходим её и начинаем получать loss ещё более низкий, чем в лучшей точке “классического режима”. То есть в большие модели и нулевой лосс на трейне вкладываться полезно, нет от них адского переобучения. Двойной спуск по числу итераций показывает, что и тупо по времени обучения похожая картинка тоже может возникать — обучаете модель дольше, когда валидационный лосс начал уже возрастать, и в какой-то момент он снова начинает уменьшаться (привет early stopping’у, не позволяющему многим дойти дотуда). Про это, кстати, у Коли тоже был вебинар с Дмитрием Ветровым, рекомендую (https://ntr.ai/webinar/nauchno-tehnicheskij-vebinar-neobychnye-svojstva-funkczii-poter-v-glubinnom-obuchenii/).
Всё это контринтуитивные истории, теоретическая база под которые ещё не подведена. Вот с гроккингом тоже что-то похожее. И более того, двойной спуск по числу итераций в данной работе также показали.
Дополнительно повизуализировали матрицы весов выходного слоя, увидели там отражение структур собственно стоящих за этим математических объектов. Это в принципе прикольное направление, может будет полезно для исследования новых математических объектов и получения интуиции относительно них.
Такие вот дела. Всем успешного гроккинга!
Машинное обучение+2