Любопытная работа. По названию я ожидал, что будет что-то из серии про отказ от бэкпропа и/или про стремление к более биологически-адеквантым подходам, может быть что-то типа синтетических градиентов (крутая работа, кстати, если кто не знаком: https://deepmind.com/blog/article/decoupled-neural-networks-using-synthetic-gradients). Но неожиданно это оказалось про заход вообще с другой стороны, где первое и второе как бы даже верно, но используется при этом всё тот же autodiff, и градиент получается примерно тот же самый, но без бэкпропа. С форвардпропом.
В современных фреймворках практически везде сейчас используется автоматическое дифференцирование (autodiff, AD), которое позволяет писать код нейросетей (причём не только статическое перемножение матриц, но и произвольные математические операции и даже управление потоком выполнения с ветвлением и условиями) и не думать про производные. Это, конечно, существенное упрощение жизни. Кто помнит, что было лет 10 назад (или проходил курс Хинтона в 2012-м), не может не оценить. Тогда чтобы запрограммировать сеть и реализовать бэкпроп, надо было просчитать вручную производные ваших слоёв, что не всегда было тривиально и часто вело к ошибкам. Сейчас, благодаря AD, пиши что хочешь, компьютер сам посчитает.
Если вкратце про AD, то его важно различать с символьным дифференцированием и с численным дифференцированием. Второе в DL обычно не вариант (медленное и неточное), использовалось разве что для проверки, что вы правильно посчитали производные. Символьное вроде кое-где использовалось (я, кстати, думал, что Theano было на нём, но похоже, что оно всё же тоже на AD, хоть и называли его символьным), но сейчас я знаю разве что SymPy. Современные фреймворки, кажется, все на AD.
У AD есть два варианта: reverse-mode AD (в своём частном случае известный как backprop) и менее известный forward-mode AD.
Forward-mode позволяет получить частные производные каждого из выходов по конкретному входу за один проход и эффективен, когда входов меньше, чем выходов. В общем виде он позволяет посчитать произведение якобиана на вектор (JVP, Jacobian-vector product), где единичный вектор на какой-то оси даст частную производную по этой координате, а произвольный вектор даст производную по направлению этого вектора, directional derivative. Что круто, и значение самой функции, и jvp вычисляются за один проход.
Reverse-mode эффективен, когда входов больше, чем выходов. Этот вариант двухфазный, когда сначала надо сделать forward pass, то есть прогон значений через функцию, запомнить по ходу дела промежуточные вычисления, а затем на обратном проходе посчитать градиенты. Этот режим вычисляет произведение вектора на якобиан (VJP, vector-Jacobian product) и для единичного вектора это собственно полноценный градиент, то есть частные производные по всем входам.
Обе эти операции производятся без необходимости прямого вычисления всего якобиана.
Важно, что forward и reverse режимы вычисляют различные вещи, собственно производную по заданному вектором направлению или полный градиент. Reverse режим обычно более дорогостоящий, потому что, во-первых, два прохода, а во-вторых надо сохранять результаты операций по ходу дела (и для этого используются все эти gradient tape или что-то подобное, в pytorch вроде как это через граф делается, но пишут они про это своеобразно — то tape, то не tape: https://discuss.pytorch.org/t/is-pytorch-autograd-tape-based/13992). Но тут правда надо ещё сделать поправку на форму матрицы, чего там больше, строк или столбцов (соответственно входов или выходов).
В текущей работе авторы предлагают метод расчёта градиента на основе производных по направлению через forward-mode AD и называют такую постановку forward gradient. Forward-prop’ом его специально не стали называть, потому что под этим часто подразумевается обычный forward pass как первый этап бэкпропа.
Авторы понятно доказывают, что если брать случайный вектор (той же размерности что и вход) так, что его компоненты не зависят друг от друга и имеют среднее и дисперсию соответственно 0 и 1, то forward gradient является несмещённой оценкой обычного градиента (то есть распределение, из которого сэмплятся эти вектора, многомерное нормальное с нулевым средним и единичной матрицей ковариации). Исследование других распределений оставляют на будущее.
Сам алгоритм спуска по “прямому градиенту” работает аналогично SGD, только лишь вместо настоящих градиентов берутся forward gradients. Назвали FGD (Forward Gradient Descent). Более продвинутые алгоритмы оптимизации типа момента или Adam также оставляют на будущее.
Реализовали это в PyTorch’е с нуля, отключив всю его собственную механику autodiff’а, то есть все тензоры это обычные тензоры с requires_grad=False, и также перегружены операции над тензорами. Сравниваются с обычным торчовым автоматическим дифференцированием (requires_grad=True и backward()). Дополнительными оптимизациями не занимались, так что на будущее задел ещё есть (хотя зачем, если есть JAX :) ).
Интересно выглядят оптимизационные траектории, для FGD они как будто более шумные (хотя я не совсем понимаю, почему они для SGD такие гладкие). Если траектории FGD действительно шумнее (им есть откуда), то это тоже интересно, и может оно даст какую-то дополнительную регуляризацию по аналогии как и сам SGD это делает, и может быть какие-то иные свойства ландшафта проявятся, если подробнее всё поисследовать.
На логрегрессии на MNIST также примерно одинаковое качество, но FGD быстрее. Относительно базового случая (просто вычисление функции, а-ля чистый forward pass без всяких градиентов), FGD медленнее в ~2.4 раза, а SGD медленнее в ~4.4 раза. То есть forward gradient быстрее классического бэкпропа и занимает примерно 55% времени последнего.
Дальше собирают MLP (1024-1024-10 + ReLU) для MNIST. Тут интересно. На низком learning rate всё примерно так же, а на более высоком FGD ещё и по итерациям начинает сходиться быстрее, чем SGD.
На простой CNN для MNIST тоже примерно всё так же.
Отдельно проверили, как эта штука скейлится с увеличением числа слоёв от 1 до 100. В целом скейлится. Относительно бэкпропа время нового алгоритма с увеличением слоёв поднимается от примерно чуть меньше 0.6 до чуть больше 0.8. Интересно, как дальше будет. Но в любом случае на всех этих примерах forward gradients быстрее и не хуже (а кое-где и лучше). При этом есть ещё и куда ускорять.
Следующими шагами, конечно, ожидаются реальные большие сети. Ждём в следующих работах. Можно не дожидаться статьи и провести самостоятельно. Никто не хочет? :)
В целом интересно, как оно себя будет вести в более реальных сложных кейсах с современными архитектурами и оптимизаторами, и особенно интересно с другими распределениями случайных векторов.
Ну и с точки зрения biological plausibility (https://arxiv.org/abs/1807.04587) тоже может оказаться более реальным, чем бэкпроп.