Superoptimizing Attention with QK Normalization
Introduction
Recent work has introduced several variants of attention, including grouped-query attention, multi-head latent attention, and grouped-query attention with query-key normalization (https://arxiv.org/abs/2405.09818). The following code snippet demonstrates the attention kernel used in Chameleon, a family of mixed-modal foundation models for understanding and generating images and text in any arbitrary sequence. The key difference from standard attention is that Chameleon normalizes the query and key before computing attention.
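One way to see what normalizing the query and key does is to look at the resulting logits. The sketch below is illustrative only: it uses NumPy rather than Mirage, and the `rms_norm` helper (no learned gain, `eps=1e-6`) is a hypothetical stand-in, not Mirage's implementation. After RMS normalization every vector has squared L2 norm of about `d`, so by Cauchy-Schwarz each logit `q . k` stays within `[-d, d]` no matter how large the raw inputs are.

```python
import numpy as np

def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Hypothetical helper: divide each vector (last axis) by its root-mean-square.
    # Assumption: no learned weight, eps=1e-6 chosen for illustration.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms

rng = np.random.default_rng(0)
d = 64  # head dimension, matching the 64 used in the graph

# Deliberately large-scale inputs (std 10) to show the scale is removed.
q = rms_norm(rng.normal(scale=10.0, size=(4, d)))
k = rms_norm(rng.normal(scale=10.0, size=(4, d)))

# Each normalized row has RMS ~1, hence L2 norm ~sqrt(d),
# so every logit is bounded in magnitude by d.
logits = q @ k.T
print(np.abs(logits).max() <= d)
```

Without the two `rms_norm` calls the logits would scale with the square of the input magnitude, which is what the normalization is there to prevent.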
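import torch
import mirage as mi

# Build a kernel graph for one attention computation with QK normalization.
graph = mi.new_kernel_graph()
Q = graph.new_input(dims=(2, 256, 64), dtype=mi.float16)   # (batch, q_len, head_dim)
K = graph.new_input(dims=(2, 64, 4096), dtype=mi.float16)  # (batch, head_dim, kv_len), stored pre-transposed
V = graph.new_input(dims=(2, 4096, 64), dtype=mi.float16)  # (batch, kv_len, head_dim)
# QK normalization: RMS-normalize queries and keys before the first matmul
Q = graph.rms_norm(Q)
K = graph.rms_norm(K)
# softmax(Q K) V, written as explicit exp / sum / divide steps
A = graph.matmul(Q, K)     # attention logits, (batch, q_len, kv_len)
E = graph.exp(A)
S = graph.reduction(E, 2)  # sum over the kv_len dimension
D = graph.div(E, S)        # normalized attention weights
O = graph.matmul(D, V)     # output, (batch, q_len, head_dim)
graph.mark_output(O)
# Search for an optimized implementation of the whole graph
optimized_graph = mi.superoptimize(graph)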
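When checking the output of a superoptimized kernel, it can help to have a plain reference that follows the graph node by node. The NumPy sketch below is such a reference under stated assumptions: `qk_norm_attention_reference` and `rms_norm` are hypothetical helpers, the RMS norm has no learned weight and uses `eps=1e-6`, and K is normalized over axis 1 because the graph stores K pre-transposed with the head dimension there. Which axis Mirage's `rms_norm` actually normalizes is not stated in the snippet, so treat that choice as an assumption. Like the graph, it applies no `1/sqrt(d)` scaling and no max subtraction before `exp`.

```python
import numpy as np

def rms_norm(x: np.ndarray, axis: int, eps: float = 1e-6) -> np.ndarray:
    # Hypothetical helper: RMS-normalize along `axis` (no learned weight).
    return x / np.sqrt(np.mean(x * x, axis=axis, keepdims=True) + eps)

def qk_norm_attention_reference(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    # Shapes follow the graph: Q (batch, q_len, d), K (batch, d, kv_len), V (batch, kv_len, d).
    Q = rms_norm(Q, axis=-1)          # normalize each query over the head dim
    K = rms_norm(K, axis=1)           # head dim is axis 1 in the pre-transposed K
    A = Q @ K                         # (batch, q_len, kv_len) logits
    E = np.exp(A)                     # mirrors graph.exp (no max subtraction)
    S = E.sum(axis=2, keepdims=True)  # mirrors graph.reduction(E, 2)
    D = E / S                         # mirrors graph.div(E, S): softmax weights
    return D @ V                      # mirrors graph.matmul(D, V)

# Small shapes so the example runs instantly on CPU.
rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 8, 16))
K = rng.standard_normal((2, 16, 32))
V = rng.standard_normal((2, 32, 16))
O = qk_norm_attention_reference(Q, K, V)
print(O.shape)  # (2, 8, 16)
```

Skipping the max subtraction is safe here only because normalization bounds the logits by the head dimension; in float16 a bound of 64 would already overflow `exp`, so a float64 reference is the safer comparison target.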