The Triton Transpiler

Introduction

The Triton transpiler is a component of the Mirage framework that converts kernel graphs into Triton code for efficient GPU execution. It generates this code automatically and validates it, building on Mirage’s built-in optimizations.

Key Features

  • Automatic dimension adjustment for Triton’s power-of-2 and minimum-16 requirements

  • Mask generation so that padded elements outside the actual dimension bounds do not affect the results

  • Built-in profiling capabilities and shared memory validation

  • Support for various tensor operations

Limitations and Solutions

Power-of-2 Dimension Requirements

Triton requires the block dimensions used inside a kernel (for example, the length of a tl.arange range) to be powers of 2. The transpiler handles this by:

  1. Automatically padding dimensions to the next power of 2

  2. Generating masks so that only the actual (unpadded) data is loaded, computed on, and stored
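The two steps above can be illustrated with a small pure-Python sketch. It is not Mirage’s code: the helper names next_power_of_2 and masked_load are hypothetical, and masked_load only emulates, on a Python list, the effect of the Triton idiom tl.load(ptr + offs, mask=offs < n, other=0.0), where lanes beyond the real extent are masked off and filled with a neutral value.

```python
from typing import List, Tuple


def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 that is >= n (n must be >= 1)."""
    if n < 1:
        raise ValueError("dimension must be positive")
    return 1 << (n - 1).bit_length()


def masked_load(data: List[float], block: int, other: float = 0.0) -> Tuple[List[float], List[bool]]:
    """Emulate a masked load of `block` lanes from `data` (hypothetical, CPU-only).

    Lanes with index < len(data) read real values; padded lanes are masked
    off and receive `other` instead.
    """
    mask = [i < len(data) for i in range(block)]
    values = [data[i] if m else other for i, m in enumerate(mask)]
    return values, mask


# A row of 100 real elements is processed as a block of 128 lanes.
row = [1.0] * 100
block = next_power_of_2(len(row))
values, mask = masked_load(row, block)
# The 28 padded lanes hold 0.0, so they do not change the sum.
print(block, sum(mask), sum(values))  # prints: 128 100 100.0
```

The fill value must be neutral for the operation that follows: 0.0 works for sums and dot products, while a max-reduction would need -inf instead.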

Matrix Multiplication Constraints

We use Triton’s built-in matrix multiplication API, tl.dot (https://triton-lang.org/main/python-api/generated/triton.language.dot.html#triton.language.dot), to perform matrix multiplication. tl.dot requires its operand dimensions to be at least 16. The transpiler handles this by:

  1. Padding each dimension to the next power of 2 that is at least 16, which meets the tl.dot requirement while keeping block shapes efficient

  2. Generating appropriate masks for the padded regions
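A minimal pure-Python sketch of this padding rule, and of why zero-filled padding is safe for matrix multiplication. The helper names (dot_block_size, pad, matmul) are hypothetical; in a real Triton kernel the padded tiles would come from masked loads with other=0.0 and be multiplied with tl.dot. Because padded rows and columns are zero, they add nothing to any dot product, so slicing the padded result back to the real shape recovers the exact answer.

```python
from typing import List

Matrix = List[List[float]]


def dot_block_size(n: int, minimum: int = 16) -> int:
    """Pad a dimension to the next power of 2, but never below `minimum`."""
    return max(minimum, 1 << (n - 1).bit_length())


def pad(a: Matrix, rows: int, cols: int) -> Matrix:
    """Zero-pad matrix `a` to shape (rows, cols), like a masked load with other=0."""
    return [[a[i][j] if i < len(a) and j < len(a[0]) else 0.0 for j in range(cols)]
            for i in range(rows)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of an (M, K) matrix and a (K, N) matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


# Real problem: (3 x 5) @ (5 x 2); every dimension is below 16.
A = [[float(i + j) for j in range(5)] for i in range(3)]
B = [[float(i * j + 1) for j in range(2)] for i in range(5)]
M, K, N = 3, 5, 2
BM, BK, BN = dot_block_size(M), dot_block_size(K), dot_block_size(N)  # all become 16

padded = matmul(pad(A, BM, BK), pad(B, BK, BN))
# Keep only the rows and columns that correspond to real data (the output mask).
result = [row[:N] for row in padded[:M]]
assert result == matmul(A, B)
```

Padding wastes some arithmetic on zeros for small shapes; that is the price of satisfying the minimum-16 constraint.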

Shared Memory Limitations

The transpiler checks the shared memory usage of each candidate kernel. When a kernel needs more shared memory than the device provides, the transpiler still generates its temporary code but marks it as invalid, so that it is skipped during profiling.
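The check can be pictured as comparing an estimate of a candidate’s shared memory needs against a per-block limit. The sketch below is hypothetical: Mirage’s actual accounting is not described here, so it simply assumes that usage is the sum of the padded tile sizes times the element size, and it takes the limit as a parameter instead of querying a device. Candidates that fail the check are kept but marked invalid, mirroring the behavior described above.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Candidate:
    """Hypothetical stand-in for one generated kernel variant."""
    name: str
    tiles: List[Tuple[int, int]]  # padded (rows, cols) of each tile kept in shared memory
    elem_bytes: int = 2           # e.g. 2 bytes per element for float16
    valid: bool = True

    def smem_bytes(self) -> int:
        # Assumed model: every tile lives in shared memory at the same time.
        return sum(r * c * self.elem_bytes for r, c in self.tiles)


def validate(candidates: List[Candidate], smem_limit_bytes: int) -> None:
    """Mark (not delete) candidates whose estimated usage exceeds the limit."""
    for c in candidates:
        c.valid = c.smem_bytes() <= smem_limit_bytes


cands = [
    Candidate("small", tiles=[(16, 64), (64, 64)]),
    Candidate("large", tiles=[(256, 256), (256, 256)]),
]
validate(cands, smem_limit_bytes=48 * 1024)  # 48 KiB, an illustrative limit
for c in cands:
    print(c.name, c.smem_bytes(), "valid" if c.valid else "invalid")
```

Marking rather than dropping keeps every candidate’s code available, which matches the description of still generating temporary code for kernels that cannot run.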

Profiler Implementation

TritonProfiler

The profiler component:

  1. Writes the generated code for each candidate to a temporary file

  2. Filters out invalid kernels

  3. Evaluates the remaining kernel variants by measuring the average execution time of each

  4. Returns the best-performing variant
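The first step, writing generated code to temporary files, can be done with the Python standard library alone. The sketch below shows one way to write a generated source string to a temporary file and import it as a module so its functions can be called. It is an illustration using tempfile and importlib, not a description of how TritonProfiler is implemented; the module name generated_kernel and the function run are hypothetical.

```python
import importlib.util
import os
import tempfile
from types import ModuleType


def load_generated_code(code: str, module_name: str = "generated_kernel") -> ModuleType:
    """Write `code` to a temporary .py file and import it as a module."""
    fd, path = tempfile.mkstemp(suffix=".py")
    with os.fdopen(fd, "w") as f:
        f.write(code)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the generated file's top-level code
    return module


# Stand-in for transpiler output; real output would define Triton kernels.
source = "def run(x):\n    return 2 * x\n"
mod = load_generated_code(source)
print(mod.run(21))  # prints: 42
```

Going through a real file rather than exec() matters for Triton: @triton.jit reads a kernel’s source text via Python’s inspect module, which needs the function to come from a file. The temporary file is left on disk here, which is handy for inspection.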

Key profiling steps:

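import torch

# all_graphs: candidate graphs from the search; input_tensors: their CUDA inputs
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
for idx, g in enumerate(all_graphs):
    # Skip candidates whose shared memory usage exceeds the device limit
    if not g.valid_kernels():
        print("muGraph {}: skipping since its shared memory usage exceeds the limit".format(idx))
        continue
    # Warmup runs (not timed)
    for _ in range(16):
        g(inputs=input_tensors)
    torch.cuda.synchronize()
    # Timed runs
    starter.record()
    for _ in range(1000):
        g(inputs=input_tensors)
    ender.record()
    torch.cuda.synchronize()
    # Average time per run, in milliseconds
    perf = starter.elapsed_time(ender) / 1000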
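The same warmup, average, and select pattern can be sketched without a GPU. The helper profile_and_select below is hypothetical and CPU-only: it times plain Python callables with time.perf_counter instead of CUDA events, takes a validity flag per candidate in place of g.valid_kernels(), and returns the index and average time of the fastest valid candidate. Warmup runs are excluded from the measurement so that one-time costs such as compilation or caching do not skew the average.

```python
import time
from typing import Callable, Optional, Sequence, Tuple


def profile_and_select(
    candidates: Sequence[Tuple[Callable[[], object], bool]],
    warmup_iters: int = 16,
    profile_iters: int = 1000,
) -> Tuple[Optional[int], float]:
    """Return (index, average ms) of the fastest valid candidate."""
    best_idx: Optional[int] = None
    best_ms = float("inf")
    for idx, (fn, is_valid) in enumerate(candidates):
        if not is_valid:
            print("candidate {}: skipping, marked invalid".format(idx))
            continue
        for _ in range(warmup_iters):  # warmup, not timed
            fn()
        start = time.perf_counter()
        for _ in range(profile_iters):
            fn()
        avg_ms = (time.perf_counter() - start) * 1000 / profile_iters
        if avg_ms < best_ms:
            best_idx, best_ms = idx, avg_ms
    return best_idx, best_ms


candidates = [
    (lambda: sum(range(10)), True),       # cheap and valid
    (lambda: sum(range(100_000)), True),  # expensive and valid
    (lambda: None, False),                # cheapest, but invalid, so never timed
]
best_idx, best_ms = profile_and_select(candidates, warmup_iters=2, profile_iters=50)
print("best candidate:", best_idx)
```

If no candidate is valid, the helper returns (None, inf), so a caller can detect that nothing executable was found.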

Usage Guide

Basic Usage

Calling graph.superoptimize with backend set to “triton” makes Mirage automatically search for candidate graphs, profile them, and return the best-performing graph it finds.

You can then call mi.generate_triton_program() on the returned graph’s cygraph and take the "code" entry of its result to obtain the Triton source, which you can write to a file for later use.

Here’s a simple example of using the Triton transpiler (also available as demo/triton_rms_norm.py):

import mirage as mi
import torch

# Create a kernel graph
graph = mi.new_kernel_graph()

# Define inputs and operations
X = graph.new_input(dims=(16, 4096), dtype=mi.float16)
W = graph.new_input(dims=(4096, 6144), dtype=mi.float16)
D = graph.rms_norm(X, normalized_shape=(4096,))
O = graph.matmul(D, W)
graph.mark_output(O)

# Optimize and generate Triton code
optimized_graph = graph.superoptimize(config="mlp", backend="triton")

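# Generate code for this GPU's compute capability (major * 10 + minor) and save it
props = torch.cuda.get_device_properties(0)
target_cc = props.major * 10 + props.minor
with open("triton_generated.py", "w") as f:
    f.write(mi.generate_triton_program(
        optimized_graph.cygraph,
        target_cc=target_cc)["code"])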

Advanced Guide

Debug Mode

The Triton backend is selected in python/mirage/kernel.py, in the following branch:

elif backend == "triton":
    return profile_and_select_best_graph(
        all_graphs,
        target_cc=torch.cuda.get_device_properties(0).major * 10
        + torch.cuda.get_device_properties(0).minor,
        warmup_iters=16,
        profile_iters=1000,
        debug_mode=False,
    )

Setting debug_mode to True prints detailed output for every piece of code generated during the search and also writes the generated code files to disk. This is useful for tracking down errors in the transpiled Triton code.
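The target_cc argument in the call above encodes the GPU’s compute capability as major * 10 + minor. The helper below is hypothetical and only reproduces that arithmetic; on a machine with a CUDA device, torch.cuda.get_device_capability(0) returns the same (major, minor) pair that the call above reads from torch.cuda.get_device_properties(0).

```python
from typing import Tuple


def target_cc_from_capability(capability: Tuple[int, int]) -> int:
    """Encode a (major, minor) compute capability as major * 10 + minor."""
    major, minor = capability
    return major * 10 + minor


# Example capabilities: 8.0 (A100), 8.6, 9.0 (H100).
for cap in [(8, 0), (8, 6), (9, 0)]:
    print(cap, "->", target_cc_from_capability(cap))

# On a CUDA machine (requires PyTorch and a GPU):
# import torch
# print(target_cc_from_capability(torch.cuda.get_device_capability(0)))
```

For example, an A100 (compute capability 8.0) corresponds to target_cc=80.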

Iteration Times

You can also change the number of warmup and profiling iterations through the warmup_iters and profile_iters parameters of the call above. The defaults are 16 and 1000, respectively.