Прикольная и малоизвестная работа с NeurIPS 2019 года. Авторы заходят с проблемы катастрофического забывания при мультизадачном обучении, но решение предлагают очень интересное.
В двух словах идея работы в том, что в одном наборе параметров можно закодировать множество моделей так, что они будут существовать в некоторой суперпозиции, не особенно мешая друг другу, и при необходимости (когда нужно применить модель под конкретную задачу) можно извлекать их индивидуально.
Работа исходит из наблюдения, что сверхпараметризованные модели содержат избыточное количество весов, большинство из которых после обучения можно удалить. Но с нуля маленькую сеть почему-то обычно обучать до аналогичного уровня качества не получается (привет гипотезе лотерейного билета). Зато, как выясняется, эту избыточность можно частично использовать иным способом — обучить одну сеть с числом параметров L на множество задач K и таким образом уменьшить эффективное количество параметров на задачу до O(L/K). Это, наверное, применимо и к обычному multi-task обучению (к той же ExT5, https://t.me/gonzo_ML/761), но конкретно здесь предлагается весьма любопытный механизм.
Для каждой задачи k выучивается свой набор параметров W_k, далее все такие параметры вместе зашиваются в суперпозицию моделей. К нужной в соответствии с задачей модели обращаются с помощью контекста C_k, который динамически “отправляет” входные данные с соответствующую модель, вытащенную из суперпозиции. В такой постановке параметры W итоговой модели являются аналогом “памяти”, а контекст C_k аналогом ключа для доступа к специфическим параметрам W_k. Эта интерпретация вдохновлена работами Пентти Канервы про ассоциативную память и hyperdimensional computing. Ну и вообще авторы текущей работы тесно работают с Канервой в Redwood Center for Theoretical Neuroscience в UC Berkeley.
Для мультизадачных сетей (а также и для всех кейсов, где распределения входных и выходных данных меняются, включая online learning) существенна проблема катастрофического забывания. Например, если учить модель последовательно на много задач, то к концу обучения модель почти всё забудет про первую задачу. Такого хочется избегать и существует множество наработок в этой теме. Хоть бы даже старый добрый replay buffer, активно используемый в RL. Но суперпозиция — это прям совершенно другой и новый подход.
Авторы называют свой метод PSP (Parameter Superposition). Они исходят из допущения, что входные данные являются сравнительно низко-размерными относительно всего пространства (что в целом кажется верно, вроде как есть консенсус про то, что, например, реальные картинки образуют некое низкоразмерное многообразие в более многомерном пространстве всех потенциально возможных картинок).
Фундаментальная операция, выполняемая сетью, это умножение входов x (из R^N) на матрицу весов W (из R^(M*N)), то есть y = Wx. Сверхпараметризация (over-parametrization) сети подразумевает, что только маленькое подпространство, охватываемое строками матрицы W в R^N, релевантно задаче.
Пусть W_1, W_2, …, W_K — это наборы параметров для каждой из K задач. Если для каждой из задач требуется только небольшое подпространство в R^N, то каждую матрицу W_k можно трансформировать с помощью задаче-специфичного линейного преобразования C^{-1}_k (это и есть контекст), так что строки каждой из итоговых матриц W_k*C^{-1}_k займут взаимно ортогональные подпространства в R^N. А раз так, то все эти преобразованные параметры можно сложить вместе в итоговую матрицу W и они не будут друг с другом интерферировать.
Получить эти параметры назад под конкретную задачу k можно с помощью обратного (к обратному C^{-1}_k :) ) преобразования с использованием контекста C_k. Полученные таким образом веса W_k’=W*C_k будут шумным вариантом W_k, так что W_k’*x = W_k*x + epsilon. Ну и вроде как можно добиться низкого значения этого шума.
Поскольку матричное умножение ассоциативно, y_k = (W*C_k)*x можно переписать как y_k = W*(C_k*x), то есть мы умножаем каждый входной вектор на соответствующую C_k, а затем уже умножаем на матрицу W. Для свёрток, кстати, разумнее делать наоборот, сначала умножать C_k на W, потому что весов там обычно меньше, чем входных данных.
Для кейса, когда обратная матрица к C_k это транспонированный вариант C_k (С^{-1}_k = C^T_k) мы имеем матрицы, задающие вращения. В общем в этом фреймворке самая соль — это придумать, как вращением перемещать входные данные x в ортогональные подпространства внутри R^N. Эти ортогональные вращения просто сделать для свехрпараметризованного входа (и пожалуй должно быть верно и для сверхпараметризованных сетей, где вход для последующих слоёв это уже выход предыдущего слоя, а он как раз уже тоже over-parameterized).
Среди вариантов таких вращений предлагаются:
1) Вращательная суперпозиция (pspRotation): случайная ортогональная матрица из распределения Хаара (https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ortho_group.html). Вещь универсальная, но тяжёлая ибо для матрицы M*M требует M^2 параметров. Число параметров можно уменьшить, если перейти к более ограниченным блочно-диагональным и диагональным матрицам.
2) Комплексная суперпозиция (pspComplex): сделать c_k векторами комплексных чисел, где каждый компонент вектора сэмплится равномерно с единичной окружности с фазой от [-pi, +pi], это даёт диагональную матрицу и M параметров (вернее, кажется, 2M).
3) Степени одного контекста (pspOnepower): взять целые степени одного вектора контекста, что даёт всего один доп.параметр на задачу.
4) Бинарная суперпозиция (pspBinary): использовать бинарные вектора контекста {-1, 1}, которые являются специальным вариантом комплексной суперпозиции с двумя разрешёнными значениями фазы {0, pi}.
Применение суперпозиции добавляется к линейным преобразованиям в каждом слое:
x^(l+1) = g(W^(l)*( c(k)^(l) ⊙ x^(l) )), где g() — нелинейность, например, ReLU.
Проверяют эти идеи на кейсах, когда распределения входных или выходных данных изменяются со временем.
В случае изменяющегося входного распределения взят Permuted MNIST, где каждые 1000 итераций производится случайная перестановка пикселей (метки классов не меняются) и это даёт новую задачу, для которой выбираются новые параметры контекста: для бинарной суперпозиции — случайный бинарный вектор, для комплексной — случайное число (так в статье, хотя по идее для комплексной это тоже вектор, а число только для варианта pspOnepower), а для вращательной — матрица. Всего итераций 50К и это соответствует 50 задачам. Качество измеряется на первой задаче. Если модель её быстро забывает, то с новыми задачами качество заметно деградирует.
Обучались полносвязные сети с двумя скрытыми слоями. Оказалось, что по сравнению с бэйзлайном, суперпозиция моделей работает намного лучше. И с увеличением размера слоя деградация уменьшается (что логично, размерность пространства растёт, проще получать ортогональность).
В зависимости от типа суперпозиции качество тоже разное. Вращательная суперпозиция (pspRotation) ожидаемо даёт самое высокое качество, но требует очень много дополнительной памяти. Комплексная (pspComplex) работает лучше бинарной (pspBinary), что тоже ожидаемо, ибо бинарная это частный случай комплексной. Степени одного комплексного числа (pspOnepower) неожиданно хорошо работают, на графике они следующие по качеству после полной вращательной суперпозиции, но в тексте на этом результате почему-то не акцентируются. Два метода борьбы с катастрофическим забыванием из других статей тоже побили.
Также проверили на задачах, где включали постепенное вращение MNIST или Fashion MNIST, так что полный оборот происходит за 1000 шагов, а контекст меняется каждые 100 шагов (каждый шаг его менять не надо, к небольшим поворотам сеть устойчива), итого получается 10 задач. Оцениваются всегда на отсутствии вращения. Стандартная сеть безумно осциллирует с сильными просадками качества при уходе от нулевого вращения, а все PSP модели ведут себя очень достойно и не деградируют настолько сильно.
Отдельная интересная задача — это выбирать контекст как-то автоматически, не предоставляя эту информацию явно в сеть. Здесь попробовали генерить его на каждый поворот, то есть за цикл из 1000 шагов генерируется 1000 новых случайных контекстов, а далее они переиспользуются в новых циклах (этот способ назвали pspFast). Получилось лучше бэйзлайна, но хуже предыдущих моделей, что в общем ожидаемо — объёмы хранимого возросли в 100 раз и каждая отдельная сеть в итоге обучалась меньше времени. Также попробовали более хитрый вариант под названием pspFastLocalMix, где в каждый момент времени вектор контекста это микс фаз в соседних временных точках. Это получилось получше, чем pspFast, то есть получается можно в модель инкорпорировать грубую информацию о нестационарности входного распределения. Удобно, когда детальной информации нет, но некоторые свойства изменений входного распределения известны. В идеале вообще было бы круто находить контекст тем же backprop’ом, что по идее для операторов контекста с непрерывной топологией, например, для комплексной, должно работать. Не видел пока продолжения этой темы.
В задачах с изменяющимся выходным распределением взяли хитрый вариант incremental CIFAR, это когда к обычному CIFAR-10 добавляются в качестве новых задач случайные другие 10 классов из CIFAR-100. Здесь тупой бейзлайн (когда выходные слои остаются старыми) не очень хорош, поэтому сделали более умный, когда для новой задачи обучается новый выходной слой. Это проверили на ResNet-18 и всё равно здесь pspBinary оказался намного лучше, чем более хитрый бэйзлайн (причём в pspBinary выходной слой был таки один, универсальный).
Отдельная интересная тема в supplementary про композицию контекстов. Для контекстов возникает своя алгебра и новые контексты можно создавать из имеющихся с помощью заданной операции. Например, можно делать композицию двух контекстов в новый, или также можно делать mixture of contexts, скажем, усреднением по окну. Ну и, кстати, pspOnepower и был как раз одним из вариантов композиции, когда один и тот же контекст возводили в разные степени.
Интересная работа, короче. Даёт много пищи для размышлений. Во-первых, это красиво. Во-вторых, много с чем перекликается и даёт возможность посмотреть по-новому на старые вещи. Вот, например, dropout наверняка через суперпозицию тоже как-то выражается и объясняется. В-третьих, это в русле работ про hyperdimensional computing, что очень интересная, но почему-то недостаточно обсуждаемая область. Суперпозицию можно также рассматривать как альтернативу сжатию сетей, хотя это само по себе мне менее интересно. Ну и с квантовыми нейросетями, наверное, вообще круто будет.