The Triton Transpiler
=====================

Introduction
------------

The Triton transpiler is a component of the Mirage framework that converts kernel graphs into Triton code for efficient GPU execution. It generates and validates code automatically, building on Mirage's built-in optimizations.

Key Features
------------

- Automatic dimension adjustment for Triton's power-of-2 and minimum-16 requirements
- Mask generation for handling dimension boundaries
- Built-in profiling and shared memory validation
- Support for various tensor operations

Limitations and Solutions
-------------------------

Power-of-2 Dimension Requirements
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Triton requires the dimensions of the tensors it operates on to be powers of 2. The transpiler handles this by:

1. Automatically padding dimensions to the next power of 2
2. Generating masks so that computation only touches the actual data

Matrix Multiplication Constraints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We use Triton's built-in matmul API, ``tl.dot``, to perform matrix multiplication. The ``tl.dot`` operation requires each dimension to be at least 16. The transpiler handles this by:

1. Padding each dimension to a power of 2 that is at least 16
2. Generating appropriate masks for the padded regions

Shared Memory Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~

The transpiler validates shared memory usage. When a kernel requires too much shared memory, the transpiler still generates the temporary code and then invalidates the kernel once it finds that the code is not executable.

Profiler Implementation
-----------------------

TritonProfiler
~~~~~~~~~~~~~~

The profiler component:

1. Writes the generated code to temporary files
2. Measures the average execution time of each kernel variant
3. Filters out invalid kernels
4. Selects and returns the best-performing variant

Key profiling steps:

.. code:: python

   import torch

   # Excerpt: all_graphs and input_tensors are defined by the surrounding code.
   # CUDA events used for timing; enable_timing is required for elapsed_time().
   starter = torch.cuda.Event(enable_timing=True)
   ender = torch.cuda.Event(enable_timing=True)

   for idx, g in enumerate(all_graphs):
       # Skip kernels that failed shared memory validation
       if not g.valid_kernels():
           print("muGraph {}: skipping since its shared memory usage exceed limit".format(idx))
           continue
       # Warmup runs
       for _ in range(16):
           g(inputs=input_tensors)
       torch.cuda.synchronize()
       # Timed runs
       starter.record()
       for _ in range(1000):
           g(inputs=input_tensors)
       ender.record()
       torch.cuda.synchronize()
       # Average time per run in milliseconds
       perf = starter.elapsed_time(ender) / 1000

Usage Guide
-----------

Basic Usage
~~~~~~~~~~~

When you call ``graph.superoptimize`` with ``backend="triton"``, Mirage automatically searches for candidate graphs, profiles them, and returns the best one it finds. You can then use ``generate_triton_program()["code"]`` to get the Triton code and write it to a file for further use.

Here is a simple example of using the Triton transpiler, available as ``demo/triton_rms_norm.py``:

.. code:: python

   import mirage as mi
   import torch

   # Create a kernel graph
   graph = mi.new_kernel_graph()

   # Define inputs and operations
   X = graph.new_input(dims=(16, 4096), dtype=mi.float16)
   W = graph.new_input(dims=(4096, 6144), dtype=mi.float16)
   D = graph.rms_norm(X, normalized_shape=(4096,))
   O = graph.matmul(D, W)
   graph.mark_output(O)

   # Optimize and generate Triton code
   optimized_graph = graph.superoptimize(config="mlp", backend="triton")

   # Generate and save the code
   with open("triton_generated.py", "w") as f:
       f.write(mi.generate_triton_program(
           optimized_graph.cygraph, target_cc=10)["code"])

Advanced Guide
--------------

Debug Mode
~~~~~~~~~~

In ``python/mirage/kernel.py``, the code related to the Triton transpiler is:

.. code:: python

   elif backend == "triton":
       return profile_and_select_best_graph(
           all_graphs,
           target_cc=torch.cuda.get_device_properties(0).major * 10
           + torch.cuda.get_device_properties(0).minor,
           warmup_iters=16,
           profile_iters=1000,
           debug_mode=False)

Setting ``debug_mode`` to ``True`` prints more detailed output for every piece of code generated during the search and also writes the generated code files out, which is useful for finding potential errors in the transpiled Triton code.

Iteration Times
~~~~~~~~~~~~~~~

You can also change the number of warmup and profiling iterations by modifying the ``warmup_iters`` and ``profile_iters`` parameters in the call above. The defaults are ``16`` and ``1000``.
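
The role of the two iteration counts can be illustrated with a minimal, CPU-only sketch. It is not Mirage's ``profile_and_select_best_graph``: the function name ``profile_and_select_best``, the ``(callable, is_valid)`` candidate pairs, and the use of ``time.perf_counter`` in place of CUDA events are all assumptions made for illustration.

.. code:: python

   import time


   def profile_and_select_best(candidates, warmup_iters=16, profile_iters=1000):
       """Return (best_index, best_avg_seconds) over the valid candidates.

       Hypothetical stand-in: each candidate is a (callable, is_valid) pair.
       """
       best_idx, best_time = None, float("inf")
       for idx, (fn, is_valid) in enumerate(candidates):
           if not is_valid:
               # Mirrors skipping kernels that fail validation
               continue
           # Warmup runs are executed but not timed
           for _ in range(warmup_iters):
               fn()
           # Timed runs: average wall-clock time per call
           start = time.perf_counter()
           for _ in range(profile_iters):
               fn()
           avg = (time.perf_counter() - start) / profile_iters
           if avg < best_time:
               best_idx, best_time = idx, avg
       return best_idx, best_time

Each valid candidate is called ``warmup_iters + profile_iters`` times, so lowering either value shortens the search at the cost of noisier averages. If every candidate is invalid, this sketch returns ``(None, float("inf"))``.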