:tocdepth: 1 ************************************** Superoptimizing Multi-Latent Attention ************************************** Introduction ============ .. code-block:: Python import torch import mirage as mi graph = mi.new_kernel_graph() Q = graph.new_input(dims=(2, 256, 64), dtype=mi.float16) K = graph.new_input(dims=(2, 64, 4096), dtype=mi.float16) V = graph.new_input(dims=(2, 4096, 64), dtype=mi.float16) A = graph.matmul(Q, K) E = graph.exp(A) S = graph.reduction(E, 2) D = graph.div(E, S) O = graph.matmul(D, V) graph.mark_output(O) optimized_graph = mi.superoptimize(graph)