Superoptimizing Low-Rank Adaptation

Introduction

Low-rank adaptation (LoRA) is widely used to adapt a pre-trained model to specialized domains and tasks. LoRA adapters are generally inserted into the linear layers of a model. They use low-rank decomposition to add a small set of trainable parameters, the factors A and B, alongside the original (frozen) model weights. The computation graph for LoRA is shown as follows.

Computation Graph for LoRA
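As a point of reference, the LoRA computation can be written in a few lines of NumPy. This is a minimal sketch that assumes the same shapes as the Mirage example below (16 tokens, a 256-to-4096 projection, rank 16). It uses float64 instead of float16 so results are easy to compare. The function name `lora_forward` is hypothetical. The parameter count at the end shows why the low-rank factors are cheap to train.

```python
import numpy as np

def lora_forward(X, W, A, B):
    """Reference LoRA forward pass: frozen path X @ W plus low-rank path (X @ A) @ B."""
    C = X @ W   # frozen pre-trained projection
    D = X @ A   # project down to rank r
    E = D @ B   # project back up to the output width
    return C + E

# Shapes mirror the Mirage example: 16 tokens, 256 -> 4096 projection, rank 16.
n, d_in, d_out, r = 16, 256, 4096, 16
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d_in))
W = rng.standard_normal((d_in, d_out))
A = rng.standard_normal((d_in, r))
B = rng.standard_normal((r, d_out))

O = lora_forward(X, W, A, B)
print(O.shape)  # (16, 4096)

# Only A and B are trained; W stays frozen.
full_params = W.size            # 256 * 4096
lora_params = A.size + B.size   # 256 * 16 + 16 * 4096
print(full_params, lora_params) # 1048576 69632
```

The adapter is equivalent to adding a rank-at-most-r update A @ B to W. Keeping it factored means the model never materializes a second 256 x 4096 matrix.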

The following code snippet demonstrates superoptimizing LoRA adapters in Mirage.

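import mirage as mi

graph = mi.new_kernel_graph()
# Inputs: activations X (16 tokens x 256 features), frozen weight W,
# and the rank-16 LoRA factors A and B
X = graph.new_input(dims=(16, 256), dtype=mi.float16)
W = graph.new_input(dims=(256, 4096), dtype=mi.float16)
A = graph.new_input(dims=(256, 16), dtype=mi.float16)
B = graph.new_input(dims=(16, 4096), dtype=mi.float16)
# Low-rank path: E = (X x A) x B
D = graph.matmul(X, A)
E = graph.matmul(D, B)
# Frozen path: C = X x W
C = graph.matmul(X, W)
# Output: O = X x W + X x A x B
O = graph.add(C, E)
# Search for the fastest equivalent uGraph
opt_graph = graph.superoptimize()

LoRA performance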

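The figure compares the relative performance (higher is better) of the LoRA kernels generated by torch, torch.compile, and Mirage on NVIDIA A100 GPUs. Using torch.compile actually results in a 2x slowdown for LoRA compared with torch, while the kernel discovered by Mirage is 1.6x faster than torch.

The following figure shows the best uGraph discovered by Mirage for LoRA. It fuses the three MatMuls and the subsequent Add into a single kernel. Mirage reorganizes the computation into two thread-block-level MatMuls by leveraging the following algebraic transformation: W x X + B x A x X = (W | B) x (X | (A x X)). Here the first | places W and B side by side, and the second stacks X on top of A x X. The equation is written with the weight on the left; the code above computes the equivalent transposed form X x W + X x A x B. The first MatMul computes A x X. The second multiplies the two concatenated operands, so the Add is absorbed into that MatMul's accumulation. The concats do not involve any computation and are performed by updating tensor offsets in GPU shared memory.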
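The transformation can be checked numerically. The NumPy sketch below uses the equation's own convention (weights on the left, tokens as columns). The sizes d_in = 256, d_out = 4096, r = 16, and n = 16 are assumptions taken from the example's dims. As a loose CPU analogy for "concat by offset", it writes each operand into an adjacent region of one preallocated buffer instead of building a new tensor from the parts. This is not how Mirage manages shared memory. It only illustrates that the concatenation amounts to choosing where each operand lands.

```python
import numpy as np

# Equation convention: weights on the left, activations as columns.
d_in, d_out, r, n = 256, 4096, 16, 16
rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))  # frozen weight
A = rng.standard_normal((r, d_in))      # down-projection
B = rng.standard_normal((d_out, r))     # up-projection
X = rng.standard_normal((d_in, n))      # input activations

# Original graph: three MatMuls plus an Add.
reference = W @ X + B @ (A @ X)

# Reorganized graph. MatMul 1: A x X, shape (r, n).
AX = A @ X

# (W | B): W and B side by side, shape (d_out, d_in + r).
WB = np.empty((d_out, d_in + r))
WB[:, :d_in] = W
WB[:, d_in:] = B

# (X | A x X): X stacked on top of A x X, shape (d_in + r, n).
stacked = np.empty((d_in + r, n))
stacked[:d_in] = X
stacked[d_in:] = AX

# MatMul 2: one reduction over d_in + r covers both paths, absorbing the Add.
fused = WB @ stacked

print(np.allclose(fused, reference))  # True
```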

uGraph for LoRA
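To make the single-kernel structure concrete, the loop below simulates it on the CPU with NumPy. This is a hypothetical sketch, not the CUDA that Mirage generates. It assumes the grid splits the output rows across thread blocks. The function name `fused_lora_kernel` and the `tile` size are illustrative choices. Each simulated block performs the two block-level MatMuls from the transformation above and writes its own slice of the output.

```python
import numpy as np

def fused_lora_kernel(X, W, A, B, tile=512):
    """CPU simulation of a single fused LoRA kernel (hypothetical sketch).

    Each loop iteration plays the role of one thread block that owns `tile`
    output rows. It runs two block-level MatMuls and writes its slice of O.
    """
    d_out, n = W.shape[0], X.shape[1]
    O = np.empty((d_out, n))
    for start in range(0, d_out, tile):  # one iteration per thread block
        rows = slice(start, min(start + tile, d_out))
        # Block-level MatMul 1: each block recomputes the small A x X (r x n)
        # so blocks never need to exchange intermediate results.
        AX = A @ X
        # Block-local concatenated operands.
        WB_tile = np.concatenate([W[rows], B[rows]], axis=1)  # (rows, d_in + r)
        XAX = np.concatenate([X, AX], axis=0)                 # (d_in + r, n)
        # Block-level MatMul 2: this block's rows of W x X + B x A x X.
        O[rows] = WB_tile @ XAX
    return O

# Usage with the example's sizes: d_in=256, d_out=4096, rank 16, 16 tokens.
rng = np.random.default_rng(0)
W = rng.standard_normal((4096, 256))
A = rng.standard_normal((16, 256))
B = rng.standard_normal((4096, 16))
X = rng.standard_normal((256, 16))

reference = W @ X + B @ (A @ X)
O = fused_lora_kernel(X, W, A, B)
print(np.allclose(O, reference))  # True
```

In this sketch, recomputing A x X in every block costs only r x d_in x n = 65,536 multiply-adds for these shapes. That is small compared with each block's share of the second MatMul (512 x 272 x 16 multiply-adds with `tile=512`). The trade-off lets the blocks run independently.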