diff --git a/.github/labeler.yml b/.github/labeler.yml index b852f4e32..ce4f78e50 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -414,6 +414,12 @@ - 'etc/config/toit.*.properties' - 'static/modes/toit-mode.ts' +'lang-triton': + - changed-files: + - any-glob-to-any-file: + - 'lib/compilers/triton.ts' + - 'etc/config/triton.*.properties' + 'lang-typescript': - changed-files: - any-glob-to-any-file: diff --git a/etc/config/triton.amazon.properties b/etc/config/triton.amazon.properties new file mode 100644 index 000000000..8fd65dc81 --- /dev/null +++ b/etc/config/triton.amazon.properties @@ -0,0 +1,52 @@ +compilers=&triton_nvidia:&triton_amd +defaultCompiler=triton_nvidia_331 +compilerType=triton +interpreted=true +supportsBinary=false +supportsExecute=false +isSemVer=true +notification=Experimental Triton support on Compiler Explorer. For tutorials, bugs reports, and feature requests, please visit here. + +group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230 +group.triton_nvidia.groupName=Triton (Nvidia) +group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32 + +compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia) +compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3 + +compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia) +compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3 + +compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia) +compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3 + +compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia) +compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3 + +compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia) +compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3 + +compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia) +compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3 + +compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia) +compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3 + +group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300 +group.triton_amd.groupName=Triton (AMD) +group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32 + +compiler.triton_amd_331.name=Triton 3.3.1 (AMD) +compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3 + +compiler.triton_amd_330.name=Triton 3.3.0 (AMD) +compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3 + +compiler.triton_amd_320.name=Triton 3.2.0 (AMD) +compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3 + +compiler.triton_amd_310.name=Triton 3.1.0 (AMD) +compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3 + +compiler.triton_amd_300.name=Triton 3.0.0 (AMD) +compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3 diff --git a/etc/config/triton.defaults.properties b/etc/config/triton.defaults.properties new file mode 100644 index 000000000..8fd65dc81 --- /dev/null +++ b/etc/config/triton.defaults.properties @@ -0,0 +1,52 @@ +compilers=&triton_nvidia:&triton_amd +defaultCompiler=triton_nvidia_331 +compilerType=triton +interpreted=true +supportsBinary=false +supportsExecute=false +isSemVer=true +notification=Experimental Triton support on Compiler Explorer. For tutorials, bugs reports, and feature requests, please visit here. + +group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230 +group.triton_nvidia.groupName=Triton (Nvidia) +group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32 + +compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia) +compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3 + +compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia) +compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3 + +compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia) +compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3 + +compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia) +compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3 + +compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia) +compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3 + +compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia) +compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3 + +compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia) +compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3 + +group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300 +group.triton_amd.groupName=Triton (AMD) +group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32 + +compiler.triton_amd_331.name=Triton 3.3.1 (AMD) +compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3 + +compiler.triton_amd_330.name=Triton 3.3.0 (AMD) +compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3 + +compiler.triton_amd_320.name=Triton 3.2.0 (AMD) +compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3 + +compiler.triton_amd_310.name=Triton 3.1.0 (AMD) +compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3 + +compiler.triton_amd_300.name=Triton 3.0.0 (AMD) +compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3 diff --git a/etc/scripts/triton_wrapper.py b/etc/scripts/triton_wrapper.py new file mode 100644 index 000000000..a13d8ee00 --- /dev/null +++ b/etc/scripts/triton_wrapper.py @@ -0,0 +1,254 @@ +import argparse +import importlib.util +import inspect +import json +import os +from pathlib import Path +from typing import Dict, Optional, Union +from unittest.mock import MagicMock + +import torch +import triton +from torch._subclasses.fake_tensor import FakeTensorMode + + +class MockCacheManager(triton.runtime.cache.CacheManager): + """ + A mock cache manager that dumps the intermediate files to a given output path. + + There are various ways to dump the intermediate files: + 1. The most obvious way is to use the `TRITON_KERNEL_DUMP` & ``TRITON_DUMP_DIR` + environment variables. e.g., + os.environ["TRITON_KERNEL_DUMP"] = "1" + os.environ["TRITON_DUMP_DIR"] = str(output_dir) + However, `TRITON_DUMP_DIR` is introduced in Triton v3.2.0 at + https://github.com/triton-lang/triton/commit/ca469d7b6b6def316b5f5ee6ad2bd19dcb840bd8, + and thus not available in older versions. + + 2. The second attempt is to patch the `default_cache_dir` function. e.g., + triton.runtime.cache.default_cache_dir = MagicMock(return_value=output_dir) + This is a bit hacky, and less flexible in terms of controlling the file output. + (In fact, Triton dumps the compiled kernels to a folder with a random name.) + + 3. Another option is to use various hooks in Triton. e.g., + triton.knobs.runtime.{jit_post_compile_hook,launch_enter_hook} + JITFunction.{compiled_hook,cache_hook} + This approach is taken by TritonParse(https://github.com/pytorch-labs/tritonparse), + but it does not support older versions of Triton prior to the following commits: + https://github.com/triton-lang/triton/commit/0e9267202532ed1709dcc12c636220cf239dc377, + https://github.com/triton-lang/triton/commit/850525276426fb9814399a8e0ee8fdf744229b02. + + 4. The current apporach is to mock a `CacheManager` class. This is the most flexible + approach, and works for all versions of Triton. + """ + + output_file: Path + + def __init__(self, key, override=False, dump=False): + self.dump = dump + # filename -> data + self.files = {} + # filename -> group dict + self.groups = {} + # current stage for a given kernel + self.stage = 0 + + def get_file(self, filename) -> Optional[str]: + return self.files.get(filename, None) + + def put(self, data, filename, binary=True) -> str: + name = Path(filename).stem + suffix = Path(filename).suffix + binary = isinstance(data, bytes) + if not binary: + data = str(data) + + # Write the final file to the output file, so that we can view it in the default assembly view. + if suffix in (".ptx", ".amdgcn"): + with open(MockCacheManager.output_file, "a") as f: + f.write(data) + f.write("\n\n") + + # Write intermediate files to the output file, so that we can see them in the Device View. + self.stage += 1 + if suffix == ".json": + path = MockCacheManager.output_file.parent / filename + with open(path, "w") as fout: + json.dump(json.loads(data), fout, indent=2) + else: + path = ( + MockCacheManager.output_file.parent + / f"{name} [stage {self.stage}]{suffix}" + ) + if not binary: + with open(path, "w") as fout: + fout.write(data) + elif suffix == ".cubin": + try: + # The try-catch is needed because `disasm` was broken in Triton v3.0.0 and v3.1.0. See + # https://github.com/triton-lang/triton/commit/f424f656b3528c47d8c48126cdccafca29e536ae + from triton.tools import disasm + + with open(path.with_suffix(".sass"), "w") as fout: + fout.write(disasm.get_sass(data)) + except Exception: + pass + + # Write the file to the "cache" + self.files[filename] = data + return filename + + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + self.groups.get(filename, None) + + def put_group(self, filename: str, group: Dict[str, str]): + self.groups[filename] = group + + +def setup_triton( + output_file: Path, + opt_pipeline_file: Path, + backend: str, + arch: Union[int, str], + warp_size: int, +): + """ + Patch Triton to dump the compiled kernels to output dir without actually running them. + + This is needed because + 1. Triton does not easily support such use case. There exists an AOT compiler at + https://github.com/triton-lang/triton/blob/main/python/triton/tools/compile.py, + but it requires a bunch of boilerplate code and also requires additional user + input to specify the kernel name, signature, etc. + 2. Even if Triton adds such support, older versions of Triton (e.g., v2.3.x) still + requirs such patching to work. + + This function is a collection of hacks. It has been tested to work with Triton + 2.3.0, 2.3.1, 3.0.0, 3.1.0, 3.2.0, 3.3.0, 3.3.1. + """ + + os.environ["TRITON_ALWAYS_COMPILE"] = "1" + if opt_pipeline_file: + os.environ["MLIR_ENABLE_DUMP"] = "1" + # Supported in Triton v3.3.0 and later since + # https://github.com/triton-lang/triton/commit/3d7d9e33e7e4cba17dc366d207af2c657bd4fbd1 + os.environ["MLIR_DUMP_PATH"] = str(opt_pipeline_file) + else: + # Disable dumping other files w/ opt_pipeline_file since they race with each other + os.environ["TRITON_CACHE_MANAGER"] = "__main__:MockCacheManager" + MockCacheManager.output_file = output_file + + # Usually, Triton compiles and run a kernel when we call `kernel[grid](args)`. + # However, we want to dump the compiled kernel without actually running it. + # The class `CompiledKernel` represents a handle to a compiled kernel, + # ready to be launched. We patch it to be a no-op. + triton.compiler.compiler.CompiledKernel = MagicMock() + + # We mock a GPU driver to avoid the need to initialize CUDA/ROCm. + # The driver is only used in runtime instead of compile time, + # so it's safe to do this. + def get_current_target(): + try: + from triton.compiler.compiler import GPUTarget + + return GPUTarget(backend=backend, arch=arch, warp_size=warp_size) + except ImportError: + # For Triton v2.3.x, we don't have GPUTarget + return (backend, arch) + + mockGPUDriver = MagicMock( + get_current_target=get_current_target, + get_benchmarker=lambda: MagicMock(return_value=[0.0]), + ) + + # Set the active driver to the mocked one. + # `DriverConfig` and `triton.runtime.driver.set_active` is introduced in Triton v3.0.0 at + # https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4 + # For older versions of Triton, we directly assign to the `_obj` field of `LazyProxy`. + try: + from triton.runtime.driver import DriverConfig + + triton.runtime.driver.set_active(mockGPUDriver) + except ImportError: + triton.runtime.driver._obj = mockGPUDriver + + # For Triton v2.3.x, there are some driver code that goes into + # the generic code path, so we need to patch it as well. + try: + from triton.compiler.backends.cuda import CUDABackend + + CUDABackend.make_launcher_stub = MagicMock() + except ImportError: + pass + + +def main( + input_file: Path, + output_file: Path, + opt_pipeline_file: Path, + backend: str, + arch: Union[int, str], + warp_size: int, +): + setup_triton(output_file, opt_pipeline_file, backend, arch, warp_size) + + # Run the script by importing it as a module + spec = importlib.util.spec_from_file_location("example", input_file) + module = importlib.util.module_from_spec(spec) + with FakeTensorMode(): + # Use FakeTensor (developed during Dynamo) to avoid actually creating tensors + # See https://docs.pytorch.org/docs/stable/torch.compiler_fake_tensor.html + # Also set the data_ptr to 0 to avoid PyTorch warning and make alignment check happy + torch._subclasses.FakeTensor.data_ptr = MagicMock(return_value=0) + spec.loader.exec_module(module) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Triton wrapper") + parser.add_argument( + "input_file", + type=Path, + help="Path to the input Python file", + ) + parser.add_argument( + "--output_file", + type=Path, + required=True, + help="Path to the output file", + ) + parser.add_argument( + "--opt_pipeline_file", + type=Path, + help="Path to the output opt pipeline file", + ) + parser.add_argument( + "--backend", + type=str, + default="cuda", + choices=["cuda", "hip"], + ) + parser.add_argument( + "--arch", + type=str, + default=None, # Default value set later based on backend + ) + parser.add_argument( + "--warp_size", + type=int, + default=32, + ) + + args = parser.parse_args() + + # Set some sane defaults for the arch + if args.arch is None: + if args.backend == "cuda": + args.arch = 90 + elif args.backend == "hip": + args.arch = "gfx942" + + # Triton expects the arch to be an int for CUDA and a string for HIP + if args.backend == "cuda": + args.arch = int(args.arch) + + main(**vars(args)) diff --git a/examples/triton/default.py b/examples/triton/default.py new file mode 100644 index 000000000..54838b155 --- /dev/null +++ b/examples/triton/default.py @@ -0,0 +1,28 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +x = torch.rand(1024) +y = torch.rand(1024) +output = torch.empty(1024) +n_elements = output.numel() +add_kernel[(1,)](x, y, output, n_elements, BLOCK_SIZE=1024) diff --git a/lib/compilers/_all.ts b/lib/compilers/_all.ts index 47b976544..de2579e23 100644 --- a/lib/compilers/_all.ts +++ b/lib/compilers/_all.ts @@ -154,6 +154,7 @@ export {TenDRACompiler} from './tendra.js'; export {TIC2000} from './tic2000.js'; export {TinyCCompiler} from './tinyc.js'; export {ToitCompiler} from './toit.js'; +export {TritonCompiler} from './triton.js'; export {TurboCCompiler} from './turboc.js'; export {TypeScriptNativeCompiler} from './typescript-native.js'; export {VCompiler} from './v.js'; diff --git a/lib/compilers/triton.ts b/lib/compilers/triton.ts new file mode 100644 index 000000000..dd91b09df --- /dev/null +++ b/lib/compilers/triton.ts @@ -0,0 +1,201 @@ +// Copyright (c) 2025, Compiler Explorer Authors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +import * as fs from 'node:fs/promises'; +import Path from 'node:path'; +import type {CompilationInfo, CompilationResult} from '../../types/compilation/compilation.interfaces.js'; +import type { + OptPipelineBackendOptions, + OptPipelineOutput, +} from '../../types/compilation/opt-pipeline-output.interfaces.js'; +import type {PreliminaryCompilerInfo} from '../../types/compiler.interfaces.js'; +import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js'; +import {BaseCompiler} from '../base-compiler.js'; +import {CompilationEnvironment} from '../compilation-env.js'; +import type {IAsmParser} from '../parsers/asm-parser.interfaces.js'; +import {AmdgpuAsmParser} from '../parsers/asm-parser-amdgpu.js'; +import {MlirAsmParser} from '../parsers/asm-parser-mlir.js'; +import {PTXAsmParser} from '../parsers/asm-parser-ptx.js'; +import {SassAsmParser} from '../parsers/asm-parser-sass.js'; +import {MlirPassDumpParser} from '../parsers/mlir-pass-dump-parser.js'; +import {parseOutput, resolvePathFromAppRoot} from '../utils.js'; +import {BaseParser} from './argument-parsers.js'; + +export class TritonCompiler extends BaseCompiler { + private compilerWrapperPath: string; + + static get key() { + return 'triton'; + } + + parserMap: Record; + mlirPassDumpParser: MlirPassDumpParser; + + constructor(compilerInfo: PreliminaryCompilerInfo, env: CompilationEnvironment) { + super(compilerInfo, env); + + this.compilerWrapperPath = + this.compilerProps('compilerWrapper', '') || resolvePathFromAppRoot('etc', 'scripts', 'triton_wrapper.py'); + + // Enable the Opt Pipeline view + this.compiler.optPipeline = {}; + // Used to parse the output of the opt pipeline + this.mlirPassDumpParser = new MlirPassDumpParser(this.compilerProps); + + // Enable the Device Viewer + this.compiler.supportsDeviceAsmView = true; + // Define parsers for the different output files displayed in the Device Viewer + const sassAsmParser = new SassAsmParser(this.compilerProps); + const ptxAsmParser = new PTXAsmParser(this.compilerProps); + const amdgpuAsmParser = new AmdgpuAsmParser(); + const mlirAsmParser = new MlirAsmParser(); + this.parserMap = { + '.ttir': mlirAsmParser, + '.ttgir': mlirAsmParser, + '.ptx': ptxAsmParser, + '.sass': sassAsmParser, + '.source': mlirAsmParser, + '.amdgcn': amdgpuAsmParser, + '.llir': mlirAsmParser, + '.json': sassAsmParser, + }; + + if (compilerInfo.group == 'triton_amd') { + this.asm = amdgpuAsmParser; + } else if (compilerInfo.group == 'triton_nvidia') { + this.asm = ptxAsmParser; + } + } + + override optionsForFilter(filters: ParseFiltersAndOutputOptions, outputFilename: string): string[] { + // See etc/scripts/triton_wrapper.py for the options + return ['-I', this.compilerWrapperPath, '--output_file', outputFilename]; + } + + override getArgumentParserClass() { + return BaseParser; + } + + override async extractDeviceCode( + result: CompilationResult, + filters: ParseFiltersAndOutputOptions, + compilationInfo: CompilationInfo, + ) { + const devices = {...result.devices}; + + const {dirPath} = result; + if (!dirPath) { + return result; + } + + // Extract the device code from the output directory + const files = await fs.readdir(dirPath); + await Promise.all( + files.map(async filename => { + const ext = Path.extname(filename); + const parser = this.parserMap[ext]; + if (!parser) { + return; + } + + // Read the file + const data = await fs.readFile(Path.join(dirPath, filename), 'utf8'); + + // Parse the assembly with line numbers + let device; + if (ext === '.llir') { + device = await this.llvmIr.process(data, { + filterDebugInfo: false, + filterIRMetadata: false, + filterAttributes: false, + filterComments: false, + noDiscardValueNames: false, + demangle: false, + }); + } else { + device = await parser.process(data, filters); + } + + Object.assign(devices, {[filename]: device}); + }), + ); + result.devices = devices; + return result; + } + + override async generateOptPipeline( + inputFilename: string, + options: string[], + filters: ParseFiltersAndOutputOptions, + optPipelineOptions: OptPipelineBackendOptions, + ): Promise { + // Call the script to generate the opt pipeline + const execOptions = this.getDefaultExecOptions(); + const outputFilename = Path.join(Path.dirname(inputFilename), 'opt_pipeline.txt'); + const optOptions = [...options, '--opt_pipeline_file', outputFilename]; + + const compileStart = performance.now(); + await this.runCompiler(this.compiler.exe, optOptions, inputFilename, execOptions); + const compileEnd = performance.now(); + + // Read the output file and parse it + try { + const rawText = await fs.readFile(outputFilename, {encoding: 'utf8'}); + const lines = parseOutput(rawText); + + const parseStart = performance.now(); + const llvmOptPipeline = await this.mlirPassDumpParser.process(lines, filters, optPipelineOptions); + const parseEnd = performance.now(); + + return { + results: llvmOptPipeline, + compileTime: compileEnd - compileStart, + parseTime: parseEnd - parseStart, + }; + } catch (e: any) { + return { + error: e.toString(), + results: {}, + compileTime: compileEnd - compileStart, + }; + } + } + + override getDefaultFilters() { + return { + intel: false, + commentOnly: false, + directives: false, + labels: false, + optOutput: true, + binary: false, + execute: false, + demangle: false, + libraryCode: false, + trim: false, + binaryObject: false, + debugCalls: false, + }; + } +} diff --git a/lib/languages.ts b/lib/languages.ts index 7ea3df222..d54b32331 100644 --- a/lib/languages.ts +++ b/lib/languages.ts @@ -894,6 +894,17 @@ const definitions: Record = { previewFilter: null, monacoDisassembly: null, }, + triton: { + name: 'Triton', + monaco: 'python', + extensions: ['.py'], + alias: [], + logoFilename: 'triton.png', + logoFilenameDark: null, + formatter: null, + previewFilter: null, + monacoDisassembly: null, + }, typescript: { name: 'TypeScript Native', monaco: 'typescript', diff --git a/lib/parsers/asm-parser-mlir.ts b/lib/parsers/asm-parser-mlir.ts new file mode 100644 index 000000000..14feca001 --- /dev/null +++ b/lib/parsers/asm-parser-mlir.ts @@ -0,0 +1,146 @@ +// Copyright (c) 2025, Compiler Explorer Authors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +import {AsmResultSource, ParsedAsmResult, ParsedAsmResultLine} from '../../types/asmresult/asmresult.interfaces.js'; +import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js'; +import * as utils from '../utils.js'; + +import {AsmParser} from './asm-parser.js'; + +export class MlirAsmParser extends AsmParser { + protected locDefRegex: RegExp; + protected locDefUnknownRegex: RegExp; + protected locRefRegex: RegExp; + protected locRefRegexReplace: RegExp; + protected inlineLocRegex: RegExp; + protected inlineLocRegexReplace: RegExp; + + constructor() { + super(); + + // Match location definitions like #loc1 = loc("/path/to/file":line:column) + this.locDefRegex = /^#(\w+)\s*=\s*loc\("([^"]+)":(\d+):(\d+)\)/; + + // Match location definitions like #loc1 = loc(unknown) + this.locDefUnknownRegex = /^#(\w+)\s*=\s*loc\(unknown\)/; + + // Match location references like loc(#loc1) + this.locRefRegex = /\s*loc\(#(\w+)\)/; + this.locRefRegexReplace = new RegExp(this.locRefRegex.source, 'g'); + + // Match inline locations like loc("/path/to/file":line:column) + this.inlineLocRegex = /\s*loc\("([^"]+)":(\d+):(\d+)\)/; + this.inlineLocRegexReplace = new RegExp(this.inlineLocRegex.source, 'g'); + } + + override processAsm(asmResult: string, filters: ParseFiltersAndOutputOptions): ParsedAsmResult { + const startTime = process.hrtime.bigint(); + + const asm: ParsedAsmResultLine[] = []; + const asmLines = utils.splitLines(asmResult); + const startingLineCount = asmLines.length; + + // First pass: extract all location definitions + const locationMap = new Map(); + for (const line of asmLines) { + const locMatch = line.match(this.locDefRegex); + if (locMatch) { + const locId = locMatch[1]; + const file = locMatch[2]; + const lineNum = Number.parseInt(locMatch[3], 10); + const column = Number.parseInt(locMatch[4], 10); + + locationMap.set(locId, { + file: utils.maskRootdir(file), + line: lineNum, + column: column, + mainsource: true, + }); + } + } + + // Second pass: process each line and associate with source information + for (const line of asmLines) { + // Skip location definition lines + if (this.locDefRegex.test(line) || this.locDefUnknownRegex.test(line)) { + continue; + } + + // Apply filters if needed + let processedLine = line; + if (filters.trim) { + processedLine = processedLine.trim(); + } + + if (filters.commentOnly && processedLine.trim().startsWith('//')) { + continue; + } + + // Find source information from location references + let source: AsmResultSource | null = null; + + // Check for location references like loc(#loc1) + const locRefMatch = line.match(this.locRefRegex); + if (locRefMatch) { + const locId = locRefMatch[1]; + source = locationMap.get(locId) || null; + // Remove location reference from the displayed text + processedLine = processedLine.replace(this.locRefRegexReplace, ''); + } else { + // Check for inline locations like loc("/path/to/file":line:column) + const inlineLocMatch = line.match(this.inlineLocRegex); + if (inlineLocMatch) { + const file = inlineLocMatch[1]; + const lineNum = Number.parseInt(inlineLocMatch[2], 10); + const column = Number.parseInt(inlineLocMatch[3], 10); + + source = { + file: utils.maskRootdir(file), + line: lineNum, + column: column, + mainsource: true, + }; + } + // Remove inline location from the displayed text + processedLine = processedLine.replace(this.inlineLocRegexReplace, ''); + } + + // Add the line to the result + asm.push({ + text: processedLine, + source: source, + labels: [], + }); + } + + const endTime = process.hrtime.bigint(); + return { + asm: asm, + labelDefinitions: {}, + languageId: 'mlir', + parsingTime: utils.deltaTimeNanoToMili(startTime, endTime), + filteredCount: startingLineCount - asm.length, + }; + } +} diff --git a/lib/parsers/mlir-pass-dump-parser.ts b/lib/parsers/mlir-pass-dump-parser.ts new file mode 100644 index 000000000..7f53458c4 --- /dev/null +++ b/lib/parsers/mlir-pass-dump-parser.ts @@ -0,0 +1,374 @@ +// Copyright (c) 2025, Compiler Explorer Authors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +import { + OptPipelineBackendOptions, + OptPipelineResults, + Pass, +} from '../../types/compilation/opt-pipeline-output.interfaces.js'; +import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js'; +import {ResultLine} from '../../types/resultline/resultline.interfaces.js'; +import {assert} from '../assert.js'; +import {PropertyGetter} from '../properties.interfaces.js'; + +// Helper function to extract pass name from header +function extractPassName(header: string): string { + if (header.startsWith('IR Dump Before ')) { + return header.slice('IR Dump Before '.length); + } + + if (header.startsWith('IR Dump After ')) { + let passName = header.slice('IR Dump After '.length); + // Handle invalidated passes + if (passName.endsWith(' (invalidated)')) { + passName = passName.slice(0, passName.length - ' (invalidated)'.length); + } + return passName; + } + assert(false, 'Unexpected pass header', header); +} + +// Ir Dump for a pass with raw lines +type PassDump = { + header: string; + affectedFunction: string | undefined; + lines: ResultLine[]; +}; + +// Ir Dump for a pass with raw lines broken into affected functions +type SplitPassDump = { + header: string; + functions: Record; +}; + +export class MlirPassDumpParser { + locationDefine: RegExp; + locationReference: RegExp; + irDumpHeader: RegExp; + functionDefine: RegExp; + moduleDefine: RegExp; + functionEnd: RegExp; + moduleEnd: RegExp; + + constructor(compilerProps: PropertyGetter) { + // Location definitions: #loc0 = loc(...) + this.locationDefine = /^#loc\d* = loc\(.+\)$/; + + // Location references: loc(#loc0), loc("/app/example.py":19:0), loc(unknown) + this.locationReference = /\s*loc\([^)]*\)/g; + + // MLIR dump headers look like "// -----// IR Dump Before/After XYZ (xyz) ('operation-type' operation: @function_name) //----- //" + this.irDumpHeader = /^\/\/ -----\/\/ (IR Dump (?:Before|After) .+) \/\/----- \/\/$/; + + // MLIR function definitions look like "func.func @function_name(...) {" + // or "tt.func public @function_name(...) {" + this.functionDefine = /^\s*(\w+\.func\s+(?:\w+\s+)?@(\w+).*\{)$/; + + // MLIR module definitions look like "module {" + this.moduleDefine = /^\s*(module\s*\{)$/; + + // Functions end with a closing brace + this.functionEnd = /^\s*\}\s*$/; + + // Modules end with a closing brace + this.moduleEnd = /^\s*\}\s*$/; + } + + breakdownOutputIntoPassDumps(ir: ResultLine[]) { + // break down output by "// -----// IR Dump Before/After XYZ //----- //" markers + const raw_passes: PassDump[] = []; + let pass: PassDump | null = null; + let lastWasBlank = false; // skip duplicate blank lines + + for (const line of ir) { + const irMatch = line.text.match(this.irDumpHeader); + + if (irMatch) { + if (pass !== null) { + raw_passes.push(pass); + } + + const headerText = irMatch[1]; + pass = { + header: headerText, + affectedFunction: undefined, + lines: [], + }; + lastWasBlank = true; // skip leading newlines after the header + } else { + if (pass === null) continue; + + if (line.text.trim() === '') { + if (!lastWasBlank) { + pass.lines.push(line); + } + lastWasBlank = true; + } else { + pass.lines.push(line); + lastWasBlank = false; + } + } + } + + if (pass !== null) { + raw_passes.push(pass); + } + + return raw_passes; + } + + breakdownPassDumpsIntoFunctions(dump: PassDump) { + // Simplified version based on the assumption that: + // 1. Functions always live inside a single module + // 2. A single module always has a single function in it + // 3. We use the name of the function to show the entire module + const pass: SplitPassDump = { + header: dump.header, + functions: {}, + }; + + // Find the function name inside the module + let functionName: string | null = null; + for (const line of dump.lines) { + const funcMatch = line.text.match(this.functionDefine); + if (funcMatch) { + functionName = funcMatch[2]; + break; + } + } + + // If we found a function name, use it; otherwise use "module" + const name = functionName || 'module'; + pass.functions[name] = dump.lines; + + return pass; + } + + breakdownIntoPassDumpsByFunction(passDumps: SplitPassDump[]) { + // Currently we have an array of passes with a map of functions altered in each pass + // We want to transpose to a map of functions with an array of passes on the function + const passDumpsByFunction: Record = {}; + + for (const pass of passDumps) { + const {header, functions} = pass; + + for (const [function_name, lines] of Object.entries(functions)) { + if (!(function_name in passDumpsByFunction)) { + passDumpsByFunction[function_name] = []; + } + + passDumpsByFunction[function_name].push({ + header, + affectedFunction: undefined, + lines, + }); + } + } + + return passDumpsByFunction; + } + + associateFullDumpsWithFunctions(passDumps: PassDump[]) { + // Currently we have an array of passes that'll have target annotations + const passDumpsByFunction: Record = {}; + + // First figure out what all the functions are + for (const pass of passDumps) { + if (pass.affectedFunction) { + passDumpsByFunction[pass.affectedFunction] = []; + } + } + + // Add a special entry for the full module + passDumpsByFunction[''] = []; + + for (const pass of passDumps) { + const {header, affectedFunction, lines} = pass; + + if (affectedFunction) { + // This pass affects a specific function + assert(affectedFunction in passDumpsByFunction); + + // Add to both the specific function and the full module view + [passDumpsByFunction[affectedFunction], passDumpsByFunction['']].map(entry => + entry.push({ + header: `${header} (${affectedFunction})`, + affectedFunction, + lines, + }), + ); + } else { + // This pass applies to everything + for (const [_, entry] of Object.entries(passDumpsByFunction)) { + entry.push({ + header, + affectedFunction: undefined, + lines, + }); + } + } + } + + return passDumpsByFunction; + } + + isIrChanged(before: ResultLine[], after: ResultLine[]) { + if (before.length !== after.length) { + return true; + } + for (let i = 0; i < before.length; i++) { + if (before[i].text !== after[i].text) { + return true; + } + } + return false; + } + + matchPassDumps(passDumpsByFunction: Record) { + // We have all the passes for each function, now we will go through each function and match the before/after + // dumps, handling the case where the same pass might occur multiple times + const final_output: OptPipelineResults = {}; + + for (const [function_name, passDumps] of Object.entries(passDumpsByFunction)) { + const passes: Pass[] = []; + + // Use a stack of "Before" passes + const beforePasses: PassDump[] = []; + + // Collect all "Before" passes in order + for (const dump of passDumps) { + if (dump.header.startsWith('IR Dump Before ')) { + beforePasses.push(dump); + } + } + + // Process "After" passes and match with "Before" passes + for (const dump of passDumps) { + if (dump.header.startsWith('IR Dump After ')) { + const afterPassName = extractPassName(dump.header); + const pass: Pass = { + name: afterPassName, + machine: false, + after: dump.lines, + before: [], + irChanged: true, + }; + + // Find matching "Before" pass by name + for (let i = 0; i < beforePasses.length; i++) { + const beforePassName = extractPassName(beforePasses[i].header); + if (beforePassName === afterPassName) { + // Found a match, use it and remove from the stack + pass.before = beforePasses[i].lines; + + // Check for equality + pass.irChanged = this.isIrChanged(pass.before, pass.after); + + // Remove the matched "Before" pass + beforePasses.splice(i, 1); + break; + } + } + + passes.push(pass); + } + } + + // If we only have before passes (no after passes), diff between consecutive before passes + // This happened in Triton since it sets enableIRPrinting(printAfterOnlyOnFailure=false) + if (passes.length === 0) { + for (let i = 0; i < beforePasses.length - 1; i++) { + const isLast = i === beforePasses.length - 1; + const passName = extractPassName(beforePasses[i].header); + const before = beforePasses[i].lines; + const after = isLast ? beforePasses[i].lines : beforePasses[i + 1].lines; + const irChanged = isLast ? false : this.isIrChanged(before, after); + const pass: Pass = { + name: passName, + machine: false, + before: before, + after: after, + irChanged: irChanged, + }; + passes.push(pass); + } + } else { + // Handle any remaining "Before" passes that don't have corresponding "After" passes + for (const beforeDump of beforePasses) { + const passName = extractPassName(beforeDump.header); + const pass: Pass = { + name: passName, + machine: false, + before: beforeDump.lines, + after: [], + irChanged: true, // Assume changed since there's no "After" to compare with + }; + passes.push(pass); + } + } + + final_output[function_name] = passes; + } + + return final_output; + } + + breakdownOutput(ir: ResultLine[], optPipelineOptions: OptPipelineBackendOptions) { + // break down output by "// -----// IR Dump Before/After XYZ //----- //" markers + const raw_passes = this.breakdownOutputIntoPassDumps(ir); + + if (optPipelineOptions.fullModule) { + const passDumpsByFunction = this.associateFullDumpsWithFunctions(raw_passes); + // Match before / after pass dumps and we're done + return this.matchPassDumps(passDumpsByFunction); + } + + // Further break down by functions in each dump + const passDumps = raw_passes.map(this.breakdownPassDumpsIntoFunctions.bind(this)); + + // Transform array of passes containing multiple functions into a map from functions to arrays of passes on + // those functions + const passDumpsByFunction = this.breakdownIntoPassDumpsByFunction(passDumps); + + // Match before / after pass dumps and we're done + return this.matchPassDumps(passDumpsByFunction); + } + + applyIrFilters(ir: ResultLine[]) { + return ir + .filter(line => line.text.match(this.locationDefine) === null) + .map(line => ({ + ...line, + text: line.text.replace(this.locationReference, ''), + })); + } + + process(output: ResultLine[], _: ParseFiltersAndOutputOptions, optPipelineOptions: OptPipelineBackendOptions) { + // Crop out any junk before the pass dumps + const ir = output.slice(output.findIndex(line => this.irDumpHeader.test(line.text))); + + const preprocessed_lines = this.applyIrFilters(ir); + return this.breakdownOutput(preprocessed_lines, optPipelineOptions); + } +} diff --git a/public/logos/triton.png b/public/logos/triton.png new file mode 100644 index 000000000..88961049b Binary files /dev/null and b/public/logos/triton.png differ diff --git a/test/asm-parser-tests.ts b/test/asm-parser-tests.ts index 2368b5402..9f58ce708 100644 --- a/test/asm-parser-tests.ts +++ b/test/asm-parser-tests.ts @@ -24,6 +24,7 @@ import {describe, expect, it} from 'vitest'; import {AsmParser} from '../lib/parsers/asm-parser.js'; +import {MlirAsmParser} from '../lib/parsers/asm-parser-mlir.js'; import {PTXAsmParser} from '../lib/parsers/asm-parser-ptx.js'; describe('AsmParser tests', () => { @@ -196,3 +197,64 @@ vprintf, }); }); }); + +describe('MlirAsmParser tests', () => { + const parser = new MlirAsmParser(); + + describe('Location handling', () => { + it('should process MLIR with location references', () => { + const input = ` +#loc = loc("":7:0) +module { + tt.func public @add_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("":7:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("":7:0), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc("":7:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("":7:0)) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + } loc(#loc) +} loc(#loc) +#loc1 = loc(unknown) +#loc2 = loc("":14:24)`; + + const result = parser.processAsm(input, {}); + + // Verify location definitions are removed + expect(result.asm.find(line => line.text.includes('#loc = loc'))).toBeUndefined(); + expect(result.asm.find(line => line.text.includes('#loc1 = loc'))).toBeUndefined(); + expect(result.asm.find(line => line.text.includes('#loc2 = loc'))).toBeUndefined(); + + // Verify location references are removed from displayed text + expect(result.asm.find(line => line.text.includes('loc(#loc)'))).toBeUndefined(); + expect(result.asm.find(line => line.text.includes('loc(#loc1)'))).toBeUndefined(); + expect(result.asm.find(line => line.text.includes('loc(#loc2)'))).toBeUndefined(); + + // Verify inline locations are removed + expect(result.asm.find(line => line.text.includes('loc(""'))).toBeUndefined(); + + // Verify source information is correctly associated + const programIdLine = result.asm.find(line => line.text.includes('tt.get_program_id')); + expect(programIdLine).toBeDefined(); + expect(programIdLine?.source).toBeDefined(); + expect(programIdLine?.source?.file).toBe(''); + expect(programIdLine?.source?.line).toBe(14); + expect(programIdLine?.source?.column).toBe(24); + + // Verify unknown locations are not associated with source information + const constantLine = result.asm.find(line => line.text.includes('arith.constant')); + expect(constantLine).toBeDefined(); + expect(constantLine?.source).toBeNull(); + + // Verify the structure is preserved + const moduleStartLine = result.asm.find(line => line.text.includes('module {')); + expect(moduleStartLine).toBeDefined(); + + const funcLine = result.asm.find(line => line.text.includes('tt.func public @add_kernel(')); + expect(funcLine).toBeDefined(); + expect(funcLine?.text?.includes('%arg0: !tt.ptr {tt.divisibility = 16 : i32}, ')).toBe(true); + expect(funcLine?.text?.includes('%arg1: !tt.ptr {tt.divisibility = 16 : i32}, ')).toBe(true); + expect(funcLine?.text?.includes('%arg2: !tt.ptr {tt.divisibility = 16 : i32}, ')).toBe(true); + expect(funcLine?.text?.includes('%arg3: i32 {tt.divisibility = 16 : i32})')).toBe(true); + + const constLine = result.asm.find(line => line.text.includes('arith.constant 1024')); + expect(constLine).toBeDefined(); + }); + }); +}); diff --git a/test/mlir-pass-dump-parser-tests.ts b/test/mlir-pass-dump-parser-tests.ts new file mode 100644 index 000000000..d58aa1649 --- /dev/null +++ b/test/mlir-pass-dump-parser-tests.ts @@ -0,0 +1,213 @@ +// Copyright (c) 2025, Compiler Explorer Authors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +import {beforeAll, describe, expect, it} from 'vitest'; + +import {MlirPassDumpParser} from '../lib/parsers/mlir-pass-dump-parser.js'; +import * as properties from '../lib/properties.js'; + +function deepCopy(obj: T): T { + return JSON.parse(JSON.stringify(obj)); +} + +describe('mlir-pass-dump-parser', () => { + let mlirPassDumpParser: MlirPassDumpParser; + + beforeAll(() => { + const fakeProps = new properties.CompilerProps({mlir: {id: 'mlir'}}, properties.fakeProps({})); + const compilerProps = (fakeProps.get as any).bind(fakeProps, 'mlir'); + mlirPassDumpParser = new MlirPassDumpParser(compilerProps); + }); + + const rawMlirDump = [ + { + text: "// -----// IR Dump Before Inliner (inline) ('builtin.module' operation) //----- //", + }, + {text: 'module {'}, + { + text: ' tt.func public @add_kernel() attributes {noinline = false} {', + }, + {text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'}, + {text: ' tt.return loc(#loc2)'}, + {text: ' } loc(#loc)'}, + {text: '} loc(#loc)'}, + {text: '#loc = loc("/app/example.py":7:0)'}, + {text: '#loc1 = loc("/app/example.py":8:24)'}, + {text: '#loc2 = loc("/app/example.py":8:4)'}, + {text: ''}, + {text: ''}, + { + text: "// -----// IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel) //----- //", + }, + {text: 'module {'}, + { + text: ' tt.func public @add_kernel() attributes {noinline = false} {', + }, + {text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'}, + {text: ' tt.return loc(#loc2)'}, + {text: ' } loc(#loc)'}, + {text: '} loc(#loc)'}, + {text: '#loc = loc("/app/example.py":7:0)'}, + {text: '#loc1 = loc("/app/example.py":8:24)'}, + {text: '#loc2 = loc("/app/example.py":8:4)'}, + {text: ''}, + {text: ''}, + { + text: "// -----// IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation) //----- //", + }, + {text: 'module {'}, + { + text: ' tt.func public @add_kernel() attributes {noinline = false} {', + }, + {text: ' tt.return loc(#loc1)'}, + {text: ' } loc(#loc)'}, + {text: '} loc(#loc)'}, + {text: '#loc = loc("/app/example.py":7:0)'}, + {text: '#loc1 = loc("/app/example.py":8:4)'}, + {text: ''}, + {text: ''}, + ]; + + it('should break down output into pass dumps', () => { + const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump)); + + expect(passDumps.length).toBe(3); + expect(passDumps[0].header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)"); + expect(passDumps[1].header).toBe( + "IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)", + ); + expect(passDumps[2].header).toBe( + "IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)", + ); + + // Check that the first pass dump has the correct lines + expect(passDumps[0].lines.length).toBe(10); + expect(passDumps[0].lines[0].text).toBe('module {'); + expect(passDumps[0].lines[1].text).toBe(' tt.func public @add_kernel() attributes {noinline = false} {'); + }); + + it('should break down pass dumps into functions', () => { + const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump)); + const splitPassDump = mlirPassDumpParser.breakdownPassDumpsIntoFunctions(passDumps[0]); + + expect(splitPassDump.header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)"); + expect(Object.keys(splitPassDump.functions)).toContain('add_kernel'); + expect(splitPassDump.functions['add_kernel'].length).toBe(10); + }); + + it('should apply IR filters to remove location information', () => { + const filtered = mlirPassDumpParser.applyIrFilters(deepCopy(rawMlirDump.slice(0, 7))); + + // Should filter out location references but keep the lines + expect(filtered.length).toBe(7); + expect(filtered[3].text).toBe(' %0 = tt.get_program_id x : i32'); + expect(filtered[4].text).toBe(' tt.return'); + expect(filtered[5].text).toBe(' }'); + expect(filtered[6].text).toBe('}'); + }); + + it('should break down output into pass dumps by function', () => { + const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump)); + const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump)); + const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps); + + expect(Object.keys(passDumpsByFunction)).toContain('add_kernel'); + expect(passDumpsByFunction['add_kernel'].length).toBe(3); + + // Check that the function has all three passes + const headers = passDumpsByFunction['add_kernel'].map(dump => dump.header); + expect(headers).toContain("IR Dump Before Inliner (inline) ('builtin.module' operation)"); + expect(headers).toContain("IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)"); + expect(headers).toContain( + "IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)", + ); + }); + + it('should detect IR changes between passes', () => { + // Create two different IR dumps to test change detection + const before = [ + {text: 'module {'}, + {text: ' tt.func public @add_kernel() {'}, + {text: ' %0 = tt.get_program_id x : i32'}, + {text: ' tt.return'}, + {text: ' }'}, + {text: '}'}, + ]; + + const afterNoChange = deepCopy(before); + expect(mlirPassDumpParser.isIrChanged(before, afterNoChange)).toBe(false); + + const afterWithChange = [ + {text: 'module {'}, + {text: ' tt.func public @add_kernel() {'}, + {text: ' tt.return'}, // Line removed + {text: ' }'}, + {text: '}'}, + ]; + expect(mlirPassDumpParser.isIrChanged(before, afterWithChange)).toBe(true); + }); + + it('should match pass dumps and detect changes', () => { + const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump)); + const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump)); + const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps); + const matchedPasses = mlirPassDumpParser.matchPassDumps(passDumpsByFunction); + + expect(Object.keys(matchedPasses)).toContain('add_kernel'); + + const addKernelPasses = matchedPasses['add_kernel']; + expect(addKernelPasses.length).toBe(2); // We should have 2 passes (comparing the 3 dumps) + + // Check the first pass (Inliner to Canonicalizer) + expect(addKernelPasses[0].name).toBe("Inliner (inline) ('builtin.module' operation)"); + expect(addKernelPasses[0].irChanged).toBe(false); // No changes between these two dumps + + // Check the second pass (Canonicalizer to TritonRewriteTensorPointer) + expect(addKernelPasses[1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)"); + expect(addKernelPasses[1].irChanged).toBe(true); // There are changes between these two dumps + }); + + it('should process the complete output correctly', () => { + const optPipelineOptions = { + fullModule: false, + filterDebugInfo: false, + filterIRMetadata: false, + noDiscardValueNames: true, + demangle: true, + libraryFunctions: false, + }; + const result = mlirPassDumpParser.process(deepCopy(rawMlirDump), {}, optPipelineOptions); + + expect(Object.keys(result)).toContain('add_kernel'); + expect(result['add_kernel'].length).toBe(2); + + // Verify the passes are correctly identified + expect(result['add_kernel'][0].name).toBe("Inliner (inline) ('builtin.module' operation)"); + expect(result['add_kernel'][1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)"); + + // Verify the IR changes are correctly detected + expect(result['add_kernel'][0].irChanged).toBe(false); + expect(result['add_kernel'][1].irChanged).toBe(true); + }); +}); diff --git a/types/languages.interfaces.ts b/types/languages.interfaces.ts index 976168207..51ad8a919 100644 --- a/types/languages.interfaces.ts +++ b/types/languages.interfaces.ts @@ -101,6 +101,7 @@ export type LanguageKey = | 'swift' | 'tablegen' | 'toit' + | 'triton' | 'typescript' | 'v' | 'vala'