Superoptimizing Gated MLP
Introduction
Gated MLP layers are used in many large language models (e.g., LLaMA-2, LLaMA-3, and their variants). The following code constructs the computation graph for a gated MLP in Mirage and superoptimizes it.
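In equation form, with input X (8 x 4096) and weight matrices W1 and W2 (each 4096 x 4096), the graph below computes the following. Here ⊙ denotes elementwise multiplication and σ is the logistic sigmoid, so the output O has shape 8 x 4096:

$$
D_1 = X W_1, \qquad D_2 = X W_2, \qquad O = \mathrm{SiLU}(D_1) \odot D_2, \qquad \mathrm{SiLU}(x) = x\,\sigma(x) = \frac{x}{1 + e^{-x}}
$$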
import torch
import mirage as mi
graph = mi.new_kernel_graph()
X = graph.new_input(dims=(8, 4096), dtype=mi.float16)
W1 = graph.new_input(dims=(4096, 4096), dtype=mi.float16)
W2 = graph.new_input(dims=(4096, 4096), dtype=mi.float16)
# Two projections of the same input
D1 = graph.matmul(X, W1)
D2 = graph.matmul(X, W2)
# Gate: apply SiLU to D1, then multiply elementwise with D2
O = graph.mul(graph.silu(D1), D2)
graph.mark_output(O)
# Superoptimize the kernel graph
optimized_graph = mi.superoptimize(graph)
# Random inputs for X, W1, and W2 on the first GPU
input_tensors = [
    torch.randn(8, 4096, dtype=torch.float16, device='cuda:0'),
    torch.randn(4096, 4096, dtype=torch.float16, device='cuda:0'),
    torch.randn(4096, 4096, dtype=torch.float16, device='cuda:0')
]
# Launch the optimized kernels
optimized_graph(input_tensors)
In the above code snippet, we first construct a kernel graph in Mirage that corresponds to the gated MLP computation and then superoptimize it with mi.superoptimize. The optimized graph returned by Mirage can be called directly as a function, which launches the optimized kernels for the gated MLP on the given input tensors.
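When checking the output of the optimized kernels, a plain reference for the same math can be handy. The sketch below is a hypothetical pure-Python helper (gated_mlp_reference, matmul, and silu are illustrative names, not part of Mirage) that computes SiLU(X W1) ⊙ (X W2) on tiny lists of floats. On real GPU tensors, the equivalent PyTorch expression would be torch.nn.functional.silu(X @ W1) * (X @ W2), compared with a tolerance suited to float16.

```python
import math


def matmul(a, b):
    """Multiply an m x k matrix by a k x n matrix, both given as lists of rows."""
    k, n = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(n)] for row in a]


def silu(x):
    """SiLU(x) = x * sigmoid(x), split by sign to avoid overflow in exp."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def gated_mlp_reference(x, w1, w2):
    """Hypothetical reference: SiLU(x @ w1) multiplied elementwise by (x @ w2)."""
    d1 = matmul(x, w1)
    d2 = matmul(x, w2)
    return [[silu(p) * q for p, q in zip(r1, r2)] for r1, r2 in zip(d1, d2)]


# Tiny example: W1 is the identity (D1 == X) and W2 is 2 * identity (D2 == 2 * X)
X = [[1.0, 2.0]]
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[2.0, 0.0], [0.0, 2.0]]
O = gated_mlp_reference(X, W1, W2)
print(O)
```

Because W1 and W2 are scaled identities here, each output element is simply SiLU(x) * 2x, which makes the result easy to verify by hand.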