Архитектура важнее размера: внедряем каузальные свертки в трансформер и получаем связный сторителлинг

Страницы:  1

Ответить
 

Professor Seleznov


Дело было вечером, делать было нечего. Я сидел за ноутом и разбирал новую идею Deepseek Engram: Лян Ванфень собрал вместе хеш-таблицы и почти-линейный трансформер - получилось дешево и сердито.
Однако есть в Engram один недостаток - он требует много RAM (каламбурчик, хаха). А хотелось архитектуру, на инференс которой не придется скидываться всем поселком.
Небольшой ликбез
Engram, по сути, перешивает токены и добавляет к ним факты. Реализовано это довольно хитро, через хеш-функцию, O(1) по сложности. Благодаря такой пристройке трансформер уделяет больше внимания на грамматику и связь слов в предложении.
Основная идея
А что если вместо дорогого по вычислениям Engram взять простые свертки? Они дешевые, быстрые и могут запомнить базовые факты.
Именно об этом я и подумал. И тут же решил проводить тесты.
К сожалению у меня нет в гараже кластера на 8xH200 (да и гаража у меня нет), поэтому обучить что-то большое не получится. Однако для быстрого эксперимента хватит Colab и его Т4 16Гб.
Архитектура модели
pic
Слой
За пару минут набросал схему в Obsidian. Теперь про каждый блок отдельно
RMSNorm
Базовый слой нормализации, в современный трансформерах без него будет тяжко.
Conv1D
Ключевоенововведение.Depthwise и kernel = 3 обогащают токены и перемешивают их. Чтобы сетка не 'поглядывала' реализовал каузальные свертки.
pic
Визуализация MQA
MQA
Довольно быстрая и дешевая реализация классического Self-Attention, но все еще не линейная или реккурентная архитектура.
FFN + SwiGLU
Два главных компонента: новая функция активации и необычное расширение в линейном слоев - x8/3 на 3 слоя вместо устоявшегося х4 на 2 слоя (позволяет сохранить то же кол-во параметров при большем кол-ве операций).
Эта комбинация отлично показала себя в моделях Llama, где была применена впервые.
Все это решил обозвать NormIs-1. Логики в названии нет абсолютно никакой.
Меньше слов - больше кода
Не стал что-то менять в нормализации и сделал самую простую версию.
сlass RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.g
Так-же сделал с FFN - просто и понятно
class SwiGLU(nn.Module):
def __init__(self, dim):
super().__init__()
hidden_dim = int(dim * 4 * 2 / 3)
self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
self.w_val = nn.Linear(dim, hidden_dim, bias=False)
self.w_out = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
gate = F.silu(self.w_gate(x))
val = self.w_val(x)
return self.w_out(gate * val)
Наивная реализация сверток. Спойлер - простой forward() потом вышел мне боком из-за медленной памяти.
class CausalConv1D(nn.Module):
def __init__(self, dim, kernel_size=3):
super().__init__()
self.pad = kernel_size - 1
self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim)
def forward(self, x):
x = x.transpose(1, 2)
x = F.pad(x, (self.pad, 0))
x = self.conv(x)
x = x.transpose(1, 2)
return x
А вот и все ноутбуки с обучением (ссылки на Colab):
Кастомная архитектура
MHA + MQA
Метрики
Один из самых важных вопросов - а как вообще оценить NormIs-1? С чем его сравнивать? Какие метрики измерять?
Введем двух дополнительных кандидатов - трансформер на MQA и на MHA без сверток.
MHA считается лучшим по качеству, но он-же медленнее всего. Это Topline
MQA - топ по скорости, но может терять в качестве. Это Baseline.
pic
Архитектура слоев у двух дополнительных кандидатов
Метрики 'интеллекта' модели - Loss (Cross-Entropy) и Perplexity. Метрики скорости - время обучения и TPS (tokens per second).
Моя цель - усидеть на двух стульях: получить интеллект уровня MHA, не потеряв при этом в скорости генерации MQA. Если NormIs-1 догонит Topline по качеству, оставшись таким же быстрым - это победа.
Сравнение
Чтобы эксперимент был честным, я зафиксировал все гиперпараметры. Изменялась только архитектура внутреннего блока.
Конфигурация:
  • Датасет: TinyStories. Идеален для микро-моделей: в нем простая лексика, но строгие требования к грамматике и логике.
  • Токенизатор: Свой собственный, обученный на 8К токенов. Это позволило не раздувать матрицу эмбеддингов и сфокусировать 'мозги' модели на смысле, а не на хранении словаря.
  • Геометрия: model_dim = 128context = 256. Компактно, но достаточно для коротких рассказов.
  • Обучение: steps = 5000batch = 64.
Итого на претрейн -
pic
токенов.
Запустил обучение и ушел пить чай. По моим расчетам каждая модель училась бы не более получаса.
pic
прочитать с соответствующей интонацией =)
И вот наступил момент Х, пора сравнивать.
Сравнение MHA Topline MQA Baseline NormIs-1
Параметры 1.84M 1.75M 1.75M
Время обучения 24:04 24:15 25:03
Val. Perplexity 7.9 8.24 7.94
Val. Loss 2.0668 2.1095 2.0713
Tokens/sec 362 339 202

Качество довольно хорошее. NormIs остался на уровне MHA, имея меньше параметров.
Но вот скорость обучения и инференса выглядит печально. А все из-за наивной реализации сверток. Граф вычислений на PyTorch должен создавать новый CUDA Kernel для каждой свертки.
Из-за этого модель значительно медленнее при инференсе, а при обучении это не так заметно. Думаю, если написать нормальный движек, то NormIs получит свои 300т/с+
Вот ссылка на папку - там графики падения лосса и примеры генерации модели.
Выводы и идеи
Результат хороший, но не прорывной. Дальше я хочу попробовать эту же конфигурации, но на большем масштабе (20М+ параметров) и на сложной задаче (например, Fineweb-Edu).
Спасибо что дочитали статью. Это мой первый опыт написания подобных текстов.
Буду рад если получится дать фидбек на мои решения. Я в ML недавно, только учусь. Будет интересно послушать людей с опытом.-Источник
 
Loading...
Error