Add Triton language and compiler (#7919)

Closes #5530. Infra:
https://github.com/compiler-explorer/infra/pull/1711. Previous work by
@siboehm in #5531.


## Summary

This pull request introduces support for the
[Triton](https://github.com/triton-lang/triton) language, a Python-based
DSL for writing highly efficient GPU kernels.


- [x] **New Language Support**: I've added comprehensive support for the
Triton programming language, allowing users to compile and inspect
Triton kernels within Compiler Explorer (see
`lib/compilers/triton.ts`).
- [x] **Python Wrapper for Compilation**: A new Python wrapper script
(`triton_wrapper.py`) has been introduced to manage Triton compilation,
patching its behavior to dump compiled kernels and intermediate
representations without requiring actual execution, and consolidating
the output for Compiler Explorer.
- [x] **Device Assembly View**: Enables viewing of generated device
assembly code (e.g., PTX, AMDGCN) and various intermediate
representations (MLIR, LLVM IR) produced by the Triton compiler.
- [x] **MLIR Parsing**: New parsers (`asm-parser-mlir.ts` and
`mlir-pass-dump-parser.ts`) have been added to correctly interpret and
display MLIR assembly and optimization pass dumps, including source
location information.
- [x] **Multi-Version & Multi-Backend Support**: Painstakingly includes
all 8 versions of Triton (from 2.2.0 to 3.3.1) that support Python
3.12. Supports both the CUDA and HIP backends for Triton 3.
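
The source-location mapping behind the MLIR parsing bullet can be sketched in a few lines. This is an illustrative, standalone Python snippet — the regexes and the sample IR line are simplified stand-ins, not the exact ones in `asm-parser-mlir.ts`:

```python
import re

# MLIR attaches source locations either inline, e.g. loc("/app/example.py":13:26),
# or via references like loc(#loc1) that point to a definition line such as
# #loc1 = loc("/app/example.py":13:26). Two passes recover the mapping.
LOC_DEF = re.compile(r'^#(\w+)\s*=\s*loc\("([^"]+)":(\d+):(\d+)\)')
LOC_REF = re.compile(r'\s*loc\(#(\w+)\)')

mlir = [
    '%0 = tt.get_program_id x : i32 loc(#loc1)',
    '#loc1 = loc("/app/example.py":13:26)',
]

# First pass: collect all location definitions.
defs = {}
for line in mlir:
    m = LOC_DEF.match(line)
    if m:
        defs[m.group(1)] = (m.group(2), int(m.group(3)), int(m.group(4)))

# Second pass: resolve references and strip them from the displayed text.
resolved = []
for line in mlir:
    if LOC_DEF.match(line):
        continue  # definition lines are metadata, not shown in the assembly view
    m = LOC_REF.search(line)
    source = defs.get(m.group(1)) if m else None
    resolved.append((LOC_REF.sub('', line), source))
```

Each displayed line ends up paired with a `(file, line, column)` tuple, which is what lets the editor highlight the matching Python source line.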

## Screenshots

Source and assembly:
<img width="1354" height="789" alt="image"
src="https://github.com/user-attachments/assets/c29650ff-2073-40e0-a9e6-ff8377094b5e"
/>



Device view for MLIR and LLVM IR: 
<img width="1402" height="670" alt="image"
src="https://github.com/user-attachments/assets/43dd5c68-ca78-41b1-9865-e97ffe3ef73c"
/>

Opt pipeline viewer:
<img width="1408" height="668" alt="image"
src="https://github.com/user-attachments/assets/429eef8c-aaac-4781-aafa-39ef0ffc7241"
/>

Diff of TTIR in Triton 3.3.1 vs 2.3.0:
<img width="1580" height="726" alt="image"
src="https://github.com/user-attachments/assets/a928c893-dd9a-4c3a-a048-14046e56a14c"
/>

CUDA & HIP:
<img width="1596" height="800" alt="image"
src="https://github.com/user-attachments/assets/c18800c3-cfad-4e5e-96de-ba92c9f236ea"
/>

## Implementation Details (and Notes for Reviewers)

- For the Device Assembly View, I implemented `MlirAsmParser` for parsing
MLIR assembly. Technically MLIR is not an assembly language, but there
is no better choice for making the source line map work with the device view.
- I implemented `MlirPassDumpParser` for processing MLIR optimization
pass dumps. I tried to subclass `LlvmPassDumpParser`, but the two turned
out to be too different for that to be worthwhile.
- `LlvmPassDumpParser` makes some assumptions that do not hold for
MLIR passes. Some effort went into making sure that the passes are
properly diff-ed, since a pass can run multiple times and passes can also
be nested (i.e., some number of `before`s followed by
some number of `after`s).
- A lot of effort went into `patch_triton` to make sure that we
only compile the kernel without actually running it, and that this
works across all the versions we support.
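
The nesting issue described above can be illustrated with a small standalone sketch. This is not the actual `MlirPassDumpParser` logic — just a hypothetical demonstration of why before/after dumps cannot simply be zipped pairwise:

```python
# MLIR pass dumps can nest: a pass manager may emit
#   IR Dump Before A, IR Dump Before B, IR Dump After B, IR Dump After A
# so each "After X" must be matched with the most recent unmatched
# "Before X" — a stack discipline rather than a pairwise zip.
def pair_pass_dumps(headers):
    stack, order = [], []
    for header in headers:
        if header.startswith('IR Dump Before '):
            stack.append(header[len('IR Dump Before '):])
        elif header.startswith('IR Dump After '):
            name = header[len('IR Dump After '):]
            if name.endswith(' (invalidated)'):  # some passes report invalidation
                name = name[:-len(' (invalidated)')]
            assert stack and stack[-1] == name, f'unmatched dump: {header}'
            order.append(stack.pop())
    return order

# The innermost pass completes first, so its before/after pair is diffed first.
print(pair_pass_dumps([
    'IR Dump Before Canonicalizer',
    'IR Dump Before CSE',
    'IR Dump After CSE',
    'IR Dump After Canonicalizer',
]))  # ['CSE', 'Canonicalizer']
```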

## Steps to Run Locally

1. Clone https://github.com/ShawnZhong/compiler-explorer-infra.git
2. Install Triton to `/opt/compiler-explorer/triton`:
```sh
$ cd compiler-explorer-infra
$ ./bin/ce_install install triton
$ ls /opt/compiler-explorer/triton
# v2.2.0  v2.3.0  v2.3.1  v3.0.0  v3.1.0  v3.2.0  v3.3.0  v3.3.1
```
3. Clone https://github.com/ShawnZhong/compiler-explorer.git and
checkout branch `triton`
4. Run Compiler Explorer
```sh
make EXTRA_ARGS='--language triton' dev
```
5. Enjoy

---------

Co-authored-by: Matt Godbolt <matt@godbolt.org>
Commit 8befc91a79 (parent c7056b31eb), authored by Shawn Zhong on
2025-07-30 08:15:28 -07:00, committed via GitHub.
14 changed files with 1401 additions and 0 deletions.

.github/labeler.yml

@@ -414,6 +414,12 @@
          - 'etc/config/toit.*.properties'
          - 'static/modes/toit-mode.ts'
'lang-triton':
  - changed-files:
      - any-glob-to-any-file:
          - 'lib/compilers/triton.ts'
          - 'etc/config/triton.*.properties'
'lang-typescript':
  - changed-files:
      - any-glob-to-any-file:


@@ -0,0 +1,52 @@
compilers=&triton_nvidia:&triton_amd
defaultCompiler=triton_nvidia_331
compilerType=triton
interpreted=true
supportsBinary=false
supportsExecute=false
isSemVer=true
notification=Experimental Triton support on Compiler Explorer. For tutorials, bug reports, and feature requests, please visit <a href="https://github.com/ShawnZhong/compiler-explorer-triton">here</a>.
group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230
group.triton_nvidia.groupName=Triton (Nvidia)
group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32
compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia)
compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia)
compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia)
compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia)
compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia)
compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia)
compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3
compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia)
compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3
group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300
group.triton_amd.groupName=Triton (AMD)
group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32
compiler.triton_amd_331.name=Triton 3.3.1 (AMD)
compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
compiler.triton_amd_330.name=Triton 3.3.0 (AMD)
compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
compiler.triton_amd_320.name=Triton 3.2.0 (AMD)
compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
compiler.triton_amd_310.name=Triton 3.1.0 (AMD)
compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
compiler.triton_amd_300.name=Triton 3.0.0 (AMD)
compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3


@@ -0,0 +1,52 @@
compilers=&triton_nvidia:&triton_amd
defaultCompiler=triton_nvidia_331
compilerType=triton
interpreted=true
supportsBinary=false
supportsExecute=false
isSemVer=true
notification=Experimental Triton support on Compiler Explorer. For tutorials, bug reports, and feature requests, please visit <a href="https://github.com/ShawnZhong/compiler-explorer-triton">here</a>.
group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230
group.triton_nvidia.groupName=Triton (Nvidia)
group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32
compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia)
compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia)
compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia)
compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia)
compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia)
compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia)
compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3
compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia)
compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3
group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300
group.triton_amd.groupName=Triton (AMD)
group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32
compiler.triton_amd_331.name=Triton 3.3.1 (AMD)
compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
compiler.triton_amd_330.name=Triton 3.3.0 (AMD)
compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
compiler.triton_amd_320.name=Triton 3.2.0 (AMD)
compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
compiler.triton_amd_310.name=Triton 3.1.0 (AMD)
compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
compiler.triton_amd_300.name=Triton 3.0.0 (AMD)
compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3


@@ -0,0 +1,254 @@
import argparse
import importlib.util
import inspect
import json
import os
from pathlib import Path
from typing import Dict, Optional, Union
from unittest.mock import MagicMock

import torch
import triton
from torch._subclasses.fake_tensor import FakeTensorMode


class MockCacheManager(triton.runtime.cache.CacheManager):
    """
    A mock cache manager that dumps the intermediate files to a given output path.

    There are various ways to dump the intermediate files:

    1. The most obvious way is to use the `TRITON_KERNEL_DUMP` & `TRITON_DUMP_DIR`
       environment variables. e.g.,

           os.environ["TRITON_KERNEL_DUMP"] = "1"
           os.environ["TRITON_DUMP_DIR"] = str(output_dir)

       However, `TRITON_DUMP_DIR` was introduced in Triton v3.2.0 at
       https://github.com/triton-lang/triton/commit/ca469d7b6b6def316b5f5ee6ad2bd19dcb840bd8,
       and is thus not available in older versions.

    2. The second attempt is to patch the `default_cache_dir` function. e.g.,

           triton.runtime.cache.default_cache_dir = MagicMock(return_value=output_dir)

       This is a bit hacky, and less flexible in terms of controlling the file output.
       (In fact, Triton dumps the compiled kernels to a folder with a random name.)

    3. Another option is to use various hooks in Triton. e.g.,

           triton.knobs.runtime.{jit_post_compile_hook,launch_enter_hook}
           JITFunction.{compiled_hook,cache_hook}

       This approach is taken by TritonParse (https://github.com/pytorch-labs/tritonparse),
       but it does not support older versions of Triton prior to the following commits:
       https://github.com/triton-lang/triton/commit/0e9267202532ed1709dcc12c636220cf239dc377,
       https://github.com/triton-lang/triton/commit/850525276426fb9814399a8e0ee8fdf744229b02.

    4. The current approach is to mock a `CacheManager` class. This is the most flexible
       approach, and works for all versions of Triton.
    """

    output_file: Path

    def __init__(self, key, override=False, dump=False):
        self.dump = dump
        # filename -> data
        self.files = {}
        # filename -> group dict
        self.groups = {}
        # current stage for a given kernel
        self.stage = 0

    def get_file(self, filename) -> Optional[str]:
        return self.files.get(filename, None)

    def put(self, data, filename, binary=True) -> str:
        name = Path(filename).stem
        suffix = Path(filename).suffix
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)

        # Write the final file to the output file, so that we can view it in the default assembly view.
        if suffix in (".ptx", ".amdgcn"):
            with open(MockCacheManager.output_file, "a") as f:
                f.write(data)
                f.write("\n\n")

        # Write intermediate files to the output file, so that we can see them in the Device View.
        self.stage += 1
        if suffix == ".json":
            path = MockCacheManager.output_file.parent / filename
            with open(path, "w") as fout:
                json.dump(json.loads(data), fout, indent=2)
        else:
            path = (
                MockCacheManager.output_file.parent
                / f"{name} [stage {self.stage}]{suffix}"
            )
            if not binary:
                with open(path, "w") as fout:
                    fout.write(data)
            elif suffix == ".cubin":
                try:
                    # The try-except is needed because `disasm` was broken in Triton v3.0.0 and v3.1.0. See
                    # https://github.com/triton-lang/triton/commit/f424f656b3528c47d8c48126cdccafca29e536ae
                    from triton.tools import disasm

                    with open(path.with_suffix(".sass"), "w") as fout:
                        fout.write(disasm.get_sass(data))
                except Exception:
                    pass

        # Write the file to the "cache"
        self.files[filename] = data
        return filename

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        return self.groups.get(filename, None)

    def put_group(self, filename: str, group: Dict[str, str]):
        self.groups[filename] = group


def setup_triton(
    output_file: Path,
    opt_pipeline_file: Path,
    backend: str,
    arch: Union[int, str],
    warp_size: int,
):
    """
    Patch Triton to dump the compiled kernels to the output dir without actually running them.

    This is needed because

    1. Triton does not easily support such a use case. There exists an AOT compiler at
       https://github.com/triton-lang/triton/blob/main/python/triton/tools/compile.py,
       but it requires a bunch of boilerplate code and also requires additional user
       input to specify the kernel name, signature, etc.
    2. Even if Triton adds such support, older versions of Triton (e.g., v2.3.x) would
       still require such patching to work.

    This function is a collection of hacks. It has been tested to work with Triton
    2.3.0, 2.3.1, 3.0.0, 3.1.0, 3.2.0, 3.3.0, 3.3.1.
    """
    os.environ["TRITON_ALWAYS_COMPILE"] = "1"

    if opt_pipeline_file:
        os.environ["MLIR_ENABLE_DUMP"] = "1"
        # Supported in Triton v3.3.0 and later since
        # https://github.com/triton-lang/triton/commit/3d7d9e33e7e4cba17dc366d207af2c657bd4fbd1
        os.environ["MLIR_DUMP_PATH"] = str(opt_pipeline_file)
    else:
        # Disable dumping other files w/ opt_pipeline_file since they race with each other
        os.environ["TRITON_CACHE_MANAGER"] = "__main__:MockCacheManager"
        MockCacheManager.output_file = output_file

    # Usually, Triton compiles and runs a kernel when we call `kernel[grid](args)`.
    # However, we want to dump the compiled kernel without actually running it.
    # The class `CompiledKernel` represents a handle to a compiled kernel,
    # ready to be launched. We patch it to be a no-op.
    triton.compiler.compiler.CompiledKernel = MagicMock()

    # We mock a GPU driver to avoid the need to initialize CUDA/ROCm.
    # The driver is only used at runtime instead of compile time,
    # so it's safe to do this.
    def get_current_target():
        try:
            from triton.compiler.compiler import GPUTarget

            return GPUTarget(backend=backend, arch=arch, warp_size=warp_size)
        except ImportError:
            # For Triton v2.3.x, we don't have GPUTarget
            return (backend, arch)

    mockGPUDriver = MagicMock(
        get_current_target=get_current_target,
        get_benchmarker=lambda: MagicMock(return_value=[0.0]),
    )

    # Set the active driver to the mocked one.
    # `DriverConfig` and `triton.runtime.driver.set_active` were introduced in Triton v3.0.0 at
    # https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4
    # For older versions of Triton, we directly assign to the `_obj` field of `LazyProxy`.
    try:
        from triton.runtime.driver import DriverConfig

        triton.runtime.driver.set_active(mockGPUDriver)
    except ImportError:
        triton.runtime.driver._obj = mockGPUDriver

    # For Triton v2.3.x, there is some driver code that goes into
    # the generic code path, so we need to patch it as well.
    try:
        from triton.compiler.backends.cuda import CUDABackend

        CUDABackend.make_launcher_stub = MagicMock()
    except ImportError:
        pass


def main(
    input_file: Path,
    output_file: Path,
    opt_pipeline_file: Path,
    backend: str,
    arch: Union[int, str],
    warp_size: int,
):
    setup_triton(output_file, opt_pipeline_file, backend, arch, warp_size)

    # Run the script by importing it as a module
    spec = importlib.util.spec_from_file_location("example", input_file)
    module = importlib.util.module_from_spec(spec)
    with FakeTensorMode():
        # Use FakeTensor (developed during Dynamo) to avoid actually creating tensors
        # See https://docs.pytorch.org/docs/stable/torch.compiler_fake_tensor.html
        # Also set the data_ptr to 0 to avoid a PyTorch warning and make the alignment check happy
        torch._subclasses.FakeTensor.data_ptr = MagicMock(return_value=0)
        spec.loader.exec_module(module)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Triton wrapper")
    parser.add_argument(
        "input_file",
        type=Path,
        help="Path to the input Python file",
    )
    parser.add_argument(
        "--output_file",
        type=Path,
        required=True,
        help="Path to the output file",
    )
    parser.add_argument(
        "--opt_pipeline_file",
        type=Path,
        help="Path to the output opt pipeline file",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="cuda",
        choices=["cuda", "hip"],
    )
    parser.add_argument(
        "--arch",
        type=str,
        default=None,  # Default value set later based on backend
    )
    parser.add_argument(
        "--warp_size",
        type=int,
        default=32,
    )
    args = parser.parse_args()

    # Set some sane defaults for the arch
    if args.arch is None:
        if args.backend == "cuda":
            args.arch = 90
        elif args.backend == "hip":
            args.arch = "gfx942"

    # Triton expects the arch to be an int for CUDA and a string for HIP
    if args.backend == "cuda":
        args.arch = int(args.arch)

    main(**vars(args))


@@ -0,0 +1,28 @@
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


x = torch.rand(1024)
y = torch.rand(1024)
output = torch.empty(1024)
n_elements = output.numel()
add_kernel[(1,)](x, y, output, n_elements, BLOCK_SIZE=1024)


@@ -154,6 +154,7 @@ export {TenDRACompiler} from './tendra.js';
export {TIC2000} from './tic2000.js';
export {TinyCCompiler} from './tinyc.js';
export {ToitCompiler} from './toit.js';
export {TritonCompiler} from './triton.js';
export {TurboCCompiler} from './turboc.js';
export {TypeScriptNativeCompiler} from './typescript-native.js';
export {VCompiler} from './v.js';

lib/compilers/triton.ts

@@ -0,0 +1,201 @@
// Copyright (c) 2025, Compiler Explorer Authors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
import * as fs from 'node:fs/promises';
import Path from 'node:path';
import type {CompilationInfo, CompilationResult} from '../../types/compilation/compilation.interfaces.js';
import type {
OptPipelineBackendOptions,
OptPipelineOutput,
} from '../../types/compilation/opt-pipeline-output.interfaces.js';
import type {PreliminaryCompilerInfo} from '../../types/compiler.interfaces.js';
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
import {BaseCompiler} from '../base-compiler.js';
import {CompilationEnvironment} from '../compilation-env.js';
import type {IAsmParser} from '../parsers/asm-parser.interfaces.js';
import {AmdgpuAsmParser} from '../parsers/asm-parser-amdgpu.js';
import {MlirAsmParser} from '../parsers/asm-parser-mlir.js';
import {PTXAsmParser} from '../parsers/asm-parser-ptx.js';
import {SassAsmParser} from '../parsers/asm-parser-sass.js';
import {MlirPassDumpParser} from '../parsers/mlir-pass-dump-parser.js';
import {parseOutput, resolvePathFromAppRoot} from '../utils.js';
import {BaseParser} from './argument-parsers.js';
export class TritonCompiler extends BaseCompiler {
    private compilerWrapperPath: string;

    static get key() {
        return 'triton';
    }

    parserMap: Record<string, IAsmParser>;
    mlirPassDumpParser: MlirPassDumpParser;

    constructor(compilerInfo: PreliminaryCompilerInfo, env: CompilationEnvironment) {
        super(compilerInfo, env);
        this.compilerWrapperPath =
            this.compilerProps('compilerWrapper', '') || resolvePathFromAppRoot('etc', 'scripts', 'triton_wrapper.py');

        // Enable the Opt Pipeline view
        this.compiler.optPipeline = {};
        // Used to parse the output of the opt pipeline
        this.mlirPassDumpParser = new MlirPassDumpParser(this.compilerProps);

        // Enable the Device Viewer
        this.compiler.supportsDeviceAsmView = true;

        // Define parsers for the different output files displayed in the Device Viewer
        const sassAsmParser = new SassAsmParser(this.compilerProps);
        const ptxAsmParser = new PTXAsmParser(this.compilerProps);
        const amdgpuAsmParser = new AmdgpuAsmParser();
        const mlirAsmParser = new MlirAsmParser();
        this.parserMap = {
            '.ttir': mlirAsmParser,
            '.ttgir': mlirAsmParser,
            '.ptx': ptxAsmParser,
            '.sass': sassAsmParser,
            '.source': mlirAsmParser,
            '.amdgcn': amdgpuAsmParser,
            '.llir': mlirAsmParser,
            '.json': sassAsmParser,
        };

        if (compilerInfo.group == 'triton_amd') {
            this.asm = amdgpuAsmParser;
        } else if (compilerInfo.group == 'triton_nvidia') {
            this.asm = ptxAsmParser;
        }
    }

    override optionsForFilter(filters: ParseFiltersAndOutputOptions, outputFilename: string): string[] {
        // See etc/scripts/triton_wrapper.py for the options
        return ['-I', this.compilerWrapperPath, '--output_file', outputFilename];
    }

    override getArgumentParserClass() {
        return BaseParser;
    }

    override async extractDeviceCode(
        result: CompilationResult,
        filters: ParseFiltersAndOutputOptions,
        compilationInfo: CompilationInfo,
    ) {
        const devices = {...result.devices};
        const {dirPath} = result;
        if (!dirPath) {
            return result;
        }

        // Extract the device code from the output directory
        const files = await fs.readdir(dirPath);
        await Promise.all(
            files.map(async filename => {
                const ext = Path.extname(filename);
                const parser = this.parserMap[ext];
                if (!parser) {
                    return;
                }

                // Read the file
                const data = await fs.readFile(Path.join(dirPath, filename), 'utf8');

                // Parse the assembly with line numbers
                let device;
                if (ext === '.llir') {
                    device = await this.llvmIr.process(data, {
                        filterDebugInfo: false,
                        filterIRMetadata: false,
                        filterAttributes: false,
                        filterComments: false,
                        noDiscardValueNames: false,
                        demangle: false,
                    });
                } else {
                    device = await parser.process(data, filters);
                }
                Object.assign(devices, {[filename]: device});
            }),
        );
        result.devices = devices;
        return result;
    }

    override async generateOptPipeline(
        inputFilename: string,
        options: string[],
        filters: ParseFiltersAndOutputOptions,
        optPipelineOptions: OptPipelineBackendOptions,
    ): Promise<OptPipelineOutput | undefined> {
        // Call the script to generate the opt pipeline
        const execOptions = this.getDefaultExecOptions();
        const outputFilename = Path.join(Path.dirname(inputFilename), 'opt_pipeline.txt');
        const optOptions = [...options, '--opt_pipeline_file', outputFilename];

        const compileStart = performance.now();
        await this.runCompiler(this.compiler.exe, optOptions, inputFilename, execOptions);
        const compileEnd = performance.now();

        // Read the output file and parse it
        try {
            const rawText = await fs.readFile(outputFilename, {encoding: 'utf8'});
            const lines = parseOutput(rawText);
            const parseStart = performance.now();
            const llvmOptPipeline = await this.mlirPassDumpParser.process(lines, filters, optPipelineOptions);
            const parseEnd = performance.now();
            return {
                results: llvmOptPipeline,
                compileTime: compileEnd - compileStart,
                parseTime: parseEnd - parseStart,
            };
        } catch (e: any) {
            return {
                error: e.toString(),
                results: {},
                compileTime: compileEnd - compileStart,
            };
        }
    }

    override getDefaultFilters() {
        return {
            intel: false,
            commentOnly: false,
            directives: false,
            labels: false,
            optOutput: true,
            binary: false,
            execute: false,
            demangle: false,
            libraryCode: false,
            trim: false,
            binaryObject: false,
            debugCalls: false,
        };
    }
}


@@ -894,6 +894,17 @@ const definitions: Record<LanguageKey, LanguageDefinition> = {
        previewFilter: null,
        monacoDisassembly: null,
    },
    triton: {
        name: 'Triton',
        monaco: 'python',
        extensions: ['.py'],
        alias: [],
        logoFilename: 'triton.png',
        logoFilenameDark: null,
        formatter: null,
        previewFilter: null,
        monacoDisassembly: null,
    },
    typescript: {
        name: 'TypeScript Native',
        monaco: 'typescript',


@@ -0,0 +1,146 @@
// Copyright (c) 2025, Compiler Explorer Authors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
import {AsmResultSource, ParsedAsmResult, ParsedAsmResultLine} from '../../types/asmresult/asmresult.interfaces.js';
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
import * as utils from '../utils.js';
import {AsmParser} from './asm-parser.js';
export class MlirAsmParser extends AsmParser {
    protected locDefRegex: RegExp;
    protected locDefUnknownRegex: RegExp;
    protected locRefRegex: RegExp;
    protected locRefRegexReplace: RegExp;
    protected inlineLocRegex: RegExp;
    protected inlineLocRegexReplace: RegExp;

    constructor() {
        super();
        // Match location definitions like #loc1 = loc("/path/to/file":line:column)
        this.locDefRegex = /^#(\w+)\s*=\s*loc\("([^"]+)":(\d+):(\d+)\)/;
        // Match location definitions like #loc1 = loc(unknown)
        this.locDefUnknownRegex = /^#(\w+)\s*=\s*loc\(unknown\)/;
        // Match location references like loc(#loc1)
        this.locRefRegex = /\s*loc\(#(\w+)\)/;
        this.locRefRegexReplace = new RegExp(this.locRefRegex.source, 'g');
        // Match inline locations like loc("/path/to/file":line:column)
        this.inlineLocRegex = /\s*loc\("([^"]+)":(\d+):(\d+)\)/;
        this.inlineLocRegexReplace = new RegExp(this.inlineLocRegex.source, 'g');
    }

    override processAsm(asmResult: string, filters: ParseFiltersAndOutputOptions): ParsedAsmResult {
        const startTime = process.hrtime.bigint();
        const asm: ParsedAsmResultLine[] = [];
        const asmLines = utils.splitLines(asmResult);
        const startingLineCount = asmLines.length;

        // First pass: extract all location definitions
        const locationMap = new Map<string, AsmResultSource>();
        for (const line of asmLines) {
            const locMatch = line.match(this.locDefRegex);
            if (locMatch) {
                const locId = locMatch[1];
                const file = locMatch[2];
                const lineNum = Number.parseInt(locMatch[3], 10);
                const column = Number.parseInt(locMatch[4], 10);
                locationMap.set(locId, {
                    file: utils.maskRootdir(file),
                    line: lineNum,
                    column: column,
                    mainsource: true,
                });
            }
        }

        // Second pass: process each line and associate with source information
        for (const line of asmLines) {
            // Skip location definition lines
            if (this.locDefRegex.test(line) || this.locDefUnknownRegex.test(line)) {
                continue;
            }

            // Apply filters if needed
            let processedLine = line;
            if (filters.trim) {
                processedLine = processedLine.trim();
            }
            if (filters.commentOnly && processedLine.trim().startsWith('//')) {
                continue;
            }

            // Find source information from location references
            let source: AsmResultSource | null = null;

            // Check for location references like loc(#loc1)
            const locRefMatch = line.match(this.locRefRegex);
            if (locRefMatch) {
                const locId = locRefMatch[1];
                source = locationMap.get(locId) || null;
                // Remove location reference from the displayed text
                processedLine = processedLine.replace(this.locRefRegexReplace, '');
            } else {
                // Check for inline locations like loc("/path/to/file":line:column)
                const inlineLocMatch = line.match(this.inlineLocRegex);
                if (inlineLocMatch) {
                    const file = inlineLocMatch[1];
                    const lineNum = Number.parseInt(inlineLocMatch[2], 10);
                    const column = Number.parseInt(inlineLocMatch[3], 10);
                    source = {
                        file: utils.maskRootdir(file),
                        line: lineNum,
                        column: column,
                        mainsource: true,
                    };
                }
                // Remove inline location from the displayed text
                processedLine = processedLine.replace(this.inlineLocRegexReplace, '');
            }

            // Add the line to the result
            asm.push({
                text: processedLine,
                source: source,
                labels: [],
            });
        }

        const endTime = process.hrtime.bigint();
        return {
            asm: asm,
            labelDefinitions: {},
            languageId: 'mlir',
            parsingTime: utils.deltaTimeNanoToMili(startTime, endTime),
            filteredCount: startingLineCount - asm.length,
        };
    }
}


@@ -0,0 +1,374 @@
// Copyright (c) 2025, Compiler Explorer Authors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
import {
OptPipelineBackendOptions,
OptPipelineResults,
Pass,
} from '../../types/compilation/opt-pipeline-output.interfaces.js';
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
import {ResultLine} from '../../types/resultline/resultline.interfaces.js';
import {assert} from '../assert.js';
import {PropertyGetter} from '../properties.interfaces.js';
// Helper function to extract pass name from header
function extractPassName(header: string): string {
if (header.startsWith('IR Dump Before ')) {
return header.slice('IR Dump Before '.length);
}
if (header.startsWith('IR Dump After ')) {
let passName = header.slice('IR Dump After '.length);
// Handle invalidated passes
if (passName.endsWith(' (invalidated)')) {
passName = passName.slice(0, passName.length - ' (invalidated)'.length);
}
return passName;
}
assert(false, 'Unexpected pass header', header);
}
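As a quick illustration of the prefix/suffix handling above, the same logic can be exercised in isolation (the `extractPassNameDemo` name exists only for this sketch):

```typescript
// Standalone mirror of extractPassName, for illustration only:
// strips the "IR Dump Before/After " prefix and the " (invalidated)"
// suffix that MLIR appends to invalidated after-dumps.
function extractPassNameDemo(header: string): string {
    if (header.startsWith('IR Dump Before ')) {
        return header.slice('IR Dump Before '.length);
    }
    if (header.startsWith('IR Dump After ')) {
        let passName = header.slice('IR Dump After '.length);
        if (passName.endsWith(' (invalidated)')) {
            passName = passName.slice(0, passName.length - ' (invalidated)'.length);
        }
        return passName;
    }
    throw new Error(`Unexpected pass header: ${header}`);
}
```

So `'IR Dump Before Inliner (inline)'` maps to `'Inliner (inline)'`, and the `(invalidated)` marker is dropped from after-dumps before name matching.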
// IR dump for a pass, with raw lines
type PassDump = {
header: string;
affectedFunction: string | undefined;
lines: ResultLine[];
};
// IR dump for a pass, with raw lines broken down by affected function
type SplitPassDump = {
header: string;
functions: Record<string, ResultLine[]>;
};
export class MlirPassDumpParser {
locationDefine: RegExp;
locationReference: RegExp;
irDumpHeader: RegExp;
functionDefine: RegExp;
moduleDefine: RegExp;
functionEnd: RegExp;
moduleEnd: RegExp;
constructor(compilerProps: PropertyGetter) {
// Location definitions: #loc0 = loc(...)
this.locationDefine = /^#loc\d* = loc\(.+\)$/;
// Location references: loc(#loc0), loc("/app/example.py":19:0), loc(unknown)
this.locationReference = /\s*loc\([^)]*\)/g;
// MLIR dump headers look like "// -----// IR Dump Before/After XYZ (xyz) ('operation-type' operation: @function_name) //----- //"
this.irDumpHeader = /^\/\/ -----\/\/ (IR Dump (?:Before|After) .+) \/\/----- \/\/$/;
// MLIR function definitions look like "func.func @function_name(...) {"
// or "tt.func public @function_name(...) {"
this.functionDefine = /^\s*(\w+\.func\s+(?:\w+\s+)?@(\w+).*\{)$/;
// MLIR module definitions look like "module {"
this.moduleDefine = /^\s*(module\s*\{)$/;
// Functions end with a closing brace
this.functionEnd = /^\s*\}\s*$/;
// Modules end with a closing brace
this.moduleEnd = /^\s*\}\s*$/;
}
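These patterns can be sanity-checked against the line shapes the parser expects. A small sketch, with sample lines modeled on the MLIR fixtures in the tests (the `Re`-suffixed names are just for this demo):

```typescript
// The same three core patterns as in the constructor above.
const irDumpHeaderRe = /^\/\/ -----\/\/ (IR Dump (?:Before|After) .+) \/\/----- \/\/$/;
const functionDefineRe = /^\s*(\w+\.func\s+(?:\w+\s+)?@(\w+).*\{)$/;
const locationReferenceRe = /\s*loc\([^)]*\)/g;

// Illustrative lines, shaped like Triton's MLIR pass-printing output.
const sampleHeader = "// -----// IR Dump Before Inliner (inline) ('builtin.module' operation) //----- //";
const sampleFunc = '  tt.func public @add_kernel(%arg0: !tt.ptr<f32>) {';
const sampleLine = '    %0 = tt.get_program_id x : i32 loc(#loc1)';

// Group 1 of the header regex is the pass description between the markers.
const headerMatch = sampleHeader.match(irDumpHeaderRe);
// Group 2 of the function regex is the kernel name.
const funcMatch = sampleFunc.match(functionDefineRe);
// The location regex removes trailing loc(...) annotations, including leading whitespace.
const stripped = sampleLine.replace(locationReferenceRe, '');
```

Here `headerMatch?.[1]` is `"IR Dump Before Inliner (inline) ('builtin.module' operation)"`, `funcMatch?.[2]` is `'add_kernel'`, and `stripped` is the instruction with its `loc(#loc1)` suffix removed.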
breakdownOutputIntoPassDumps(ir: ResultLine[]) {
// break down output by "// -----// IR Dump Before/After XYZ //----- //" markers
const raw_passes: PassDump[] = [];
let pass: PassDump | null = null;
let lastWasBlank = false; // skip duplicate blank lines
for (const line of ir) {
const irMatch = line.text.match(this.irDumpHeader);
if (irMatch) {
if (pass !== null) {
raw_passes.push(pass);
}
const headerText = irMatch[1];
pass = {
header: headerText,
affectedFunction: undefined,
lines: [],
};
lastWasBlank = true; // skip leading newlines after the header
} else {
if (pass === null) continue;
if (line.text.trim() === '') {
if (!lastWasBlank) {
pass.lines.push(line);
}
lastWasBlank = true;
} else {
pass.lines.push(line);
lastWasBlank = false;
}
}
}
if (pass !== null) {
raw_passes.push(pass);
}
return raw_passes;
}
breakdownPassDumpsIntoFunctions(dump: PassDump) {
// Simplified version based on the assumption that:
// 1. Functions always live inside a single module
// 2. A single module always has a single function in it
// 3. We use the name of the function to show the entire module
const pass: SplitPassDump = {
header: dump.header,
functions: {},
};
// Find the function name inside the module
let functionName: string | null = null;
for (const line of dump.lines) {
const funcMatch = line.text.match(this.functionDefine);
if (funcMatch) {
functionName = funcMatch[2];
break;
}
}
// If we found a function name, use it; otherwise use "module"
const name = functionName || 'module';
pass.functions[name] = dump.lines;
return pass;
}
breakdownIntoPassDumpsByFunction(passDumps: SplitPassDump[]) {
// Currently we have an array of passes with a map of functions altered in each pass
// We want to transpose it into a map from function name to the array of passes applied to that function
const passDumpsByFunction: Record<string, PassDump[]> = {};
for (const pass of passDumps) {
const {header, functions} = pass;
for (const [function_name, lines] of Object.entries(functions)) {
if (!(function_name in passDumpsByFunction)) {
passDumpsByFunction[function_name] = [];
}
passDumpsByFunction[function_name].push({
header,
affectedFunction: undefined,
lines,
});
}
}
return passDumpsByFunction;
}
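The transpose above can be sketched with plain strings standing in for `ResultLine[]` (demo data is hypothetical):

```typescript
// Each dump maps function names to that pass's IR for the function.
type DemoDump = {header: string; functions: Record<string, string[]>};

const demoDumps: DemoDump[] = [
    {header: 'IR Dump Before Inliner', functions: {add_kernel: ['module { ... v1 }']}},
    {header: 'IR Dump Before Canonicalizer', functions: {add_kernel: ['module { ... v2 }']}},
];

// Transpose: pass -> functions becomes function -> passes, preserving pass order.
const byFunction: Record<string, {header: string; lines: string[]}[]> = {};
for (const {header, functions} of demoDumps) {
    for (const [name, lines] of Object.entries(functions)) {
        if (!(name in byFunction)) byFunction[name] = [];
        byFunction[name].push({header, lines});
    }
}
```

After the loop, `byFunction['add_kernel']` holds both dumps in their original order, which is the shape `matchPassDumps` consumes.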
associateFullDumpsWithFunctions(passDumps: PassDump[]) {
// Currently we have an array of passes, each optionally annotated with a target function
const passDumpsByFunction: Record<string, PassDump[]> = {};
// First figure out what all the functions are
for (const pass of passDumps) {
if (pass.affectedFunction) {
passDumpsByFunction[pass.affectedFunction] = [];
}
}
// Add a special entry for the full module
passDumpsByFunction['<Full Module>'] = [];
for (const pass of passDumps) {
const {header, affectedFunction, lines} = pass;
if (affectedFunction) {
// This pass affects a specific function
assert(affectedFunction in passDumpsByFunction);
// Add to both the specific function and the full module view
for (const entry of [passDumpsByFunction[affectedFunction], passDumpsByFunction['<Full Module>']]) {
entry.push({
header: `${header} (${affectedFunction})`,
affectedFunction,
lines,
});
}
} else {
// This pass applies to everything
for (const entry of Object.values(passDumpsByFunction)) {
entry.push({
header,
affectedFunction: undefined,
lines,
});
}
}
}
return passDumpsByFunction;
}
isIrChanged(before: ResultLine[], after: ResultLine[]) {
if (before.length !== after.length) {
return true;
}
for (let i = 0; i < before.length; i++) {
if (before[i].text !== after[i].text) {
return true;
}
}
return false;
}
matchPassDumps(passDumpsByFunction: Record<string, PassDump[]>) {
// We have all the passes for each function, now we will go through each function and match the before/after
// dumps, handling the case where the same pass might occur multiple times
const final_output: OptPipelineResults = {};
for (const [function_name, passDumps] of Object.entries(passDumpsByFunction)) {
const passes: Pass[] = [];
// Keep a working list of unmatched "Before" passes
const beforePasses: PassDump[] = [];
// Collect all "Before" passes in order
for (const dump of passDumps) {
if (dump.header.startsWith('IR Dump Before ')) {
beforePasses.push(dump);
}
}
// Process "After" passes and match with "Before" passes
for (const dump of passDumps) {
if (dump.header.startsWith('IR Dump After ')) {
const afterPassName = extractPassName(dump.header);
const pass: Pass = {
name: afterPassName,
machine: false,
after: dump.lines,
before: [],
irChanged: true,
};
// Find matching "Before" pass by name
for (let i = 0; i < beforePasses.length; i++) {
const beforePassName = extractPassName(beforePasses[i].header);
if (beforePassName === afterPassName) {
// Found a match: use it and drop it from the list of unmatched passes
pass.before = beforePasses[i].lines;
// Check for equality
pass.irChanged = this.isIrChanged(pass.before, pass.after);
// Remove the matched "Before" pass
beforePasses.splice(i, 1);
break;
}
}
passes.push(pass);
}
}
// If we only have "Before" passes (no "After" passes), diff between consecutive "Before" dumps.
// This happens in Triton since it sets enableIRPrinting(printAfterOnlyOnFailure=false)
if (passes.length === 0) {
// The loop excludes the last dump, which has no successor to diff against
for (let i = 0; i < beforePasses.length - 1; i++) {
const passName = extractPassName(beforePasses[i].header);
const before = beforePasses[i].lines;
const after = beforePasses[i + 1].lines;
const pass: Pass = {
name: passName,
machine: false,
before,
after,
irChanged: this.isIrChanged(before, after),
};
passes.push(pass);
}
} else {
// Handle any remaining "Before" passes that don't have corresponding "After" passes
for (const beforeDump of beforePasses) {
const passName = extractPassName(beforeDump.header);
const pass: Pass = {
name: passName,
machine: false,
before: beforeDump.lines,
after: [],
irChanged: true, // Assume changed since there's no "After" to compare with
};
passes.push(pass);
}
}
final_output[function_name] = passes;
}
return final_output;
}
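The before-only fallback in `matchPassDumps` pairs each dump with its successor, so the "after" state of one pass is the "before" state of the next. A self-contained sketch of that pairing, with hypothetical dump data:

```typescript
// Three "Before" dumps and no "After" dumps, as Triton emits them.
const beforeDumps = [
    {name: 'Inliner', lines: ['module {', '  tt.return', '}']},
    {name: 'Canonicalizer', lines: ['module {', '  tt.return', '}']},
    {name: 'TritonRewriteTensorPointer', lines: ['module {', '}']},
];

const demoPasses: {name: string; before: string[]; after: string[]; irChanged: boolean}[] = [];
for (let i = 0; i < beforeDumps.length - 1; i++) {
    const before = beforeDumps[i].lines;
    const after = beforeDumps[i + 1].lines;
    // Line-by-line comparison, equivalent to isIrChanged above.
    const irChanged = before.length !== after.length || before.some((text, j) => text !== after[j]);
    demoPasses.push({name: beforeDumps[i].name, before, after, irChanged});
}
```

N dumps yield N-1 passes, which is why the tests below expect 2 passes from 3 dumps: the Inliner pass is unchanged (identical consecutive dumps) while the Canonicalizer pass shows a change.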
breakdownOutput(ir: ResultLine[], optPipelineOptions: OptPipelineBackendOptions) {
// break down output by "// -----// IR Dump Before/After XYZ //----- //" markers
const raw_passes = this.breakdownOutputIntoPassDumps(ir);
if (optPipelineOptions.fullModule) {
const passDumpsByFunction = this.associateFullDumpsWithFunctions(raw_passes);
// Match before / after pass dumps and we're done
return this.matchPassDumps(passDumpsByFunction);
}
// Further break down by functions in each dump
const passDumps = raw_passes.map(this.breakdownPassDumpsIntoFunctions.bind(this));
// Transform array of passes containing multiple functions into a map from functions to arrays of passes on
// those functions
const passDumpsByFunction = this.breakdownIntoPassDumpsByFunction(passDumps);
// Match before / after pass dumps and we're done
return this.matchPassDumps(passDumpsByFunction);
}
applyIrFilters(ir: ResultLine[]) {
return ir
.filter(line => line.text.match(this.locationDefine) === null)
.map(line => ({
...line,
text: line.text.replace(this.locationReference, ''),
}));
}
process(output: ResultLine[], _: ParseFiltersAndOutputOptions, optPipelineOptions: OptPipelineBackendOptions) {
// Crop out any junk before the pass dumps, guarding against no header being found
// (findIndex returns -1, and output.slice(-1) would keep only the last line)
const headerIndex = output.findIndex(line => this.irDumpHeader.test(line.text));
const ir = headerIndex === -1 ? output : output.slice(headerIndex);
const preprocessed_lines = this.applyIrFilters(ir);
return this.breakdownOutput(preprocessed_lines, optPipelineOptions);
}
}

BIN
public/logos/triton.png Normal file

Binary file not shown.

Size: 7.3 KiB


@@ -24,6 +24,7 @@
import {describe, expect, it} from 'vitest';
import {AsmParser} from '../lib/parsers/asm-parser.js';
import {MlirAsmParser} from '../lib/parsers/asm-parser-mlir.js';
import {PTXAsmParser} from '../lib/parsers/asm-parser-ptx.js';
describe('AsmParser tests', () => {
@@ -196,3 +197,64 @@ vprintf,
});
});
});
describe('MlirAsmParser tests', () => {
const parser = new MlirAsmParser();
describe('Location handling', () => {
it('should process MLIR with location references', () => {
const input = `
#loc = loc("<source>":7:0)
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("<source>":7:0)) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
%0 = tt.get_program_id x : i32 loc(#loc2)
} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("<source>":14:24)`;
const result = parser.processAsm(input, {});
// Verify location definitions are removed
expect(result.asm.find(line => line.text.includes('#loc = loc'))).toBeUndefined();
expect(result.asm.find(line => line.text.includes('#loc1 = loc'))).toBeUndefined();
expect(result.asm.find(line => line.text.includes('#loc2 = loc'))).toBeUndefined();
// Verify location references are removed from displayed text
expect(result.asm.find(line => line.text.includes('loc(#loc)'))).toBeUndefined();
expect(result.asm.find(line => line.text.includes('loc(#loc1)'))).toBeUndefined();
expect(result.asm.find(line => line.text.includes('loc(#loc2)'))).toBeUndefined();
// Verify inline locations are removed
expect(result.asm.find(line => line.text.includes('loc("<source>"'))).toBeUndefined();
// Verify source information is correctly associated
const programIdLine = result.asm.find(line => line.text.includes('tt.get_program_id'));
expect(programIdLine).toBeDefined();
expect(programIdLine?.source).toBeDefined();
expect(programIdLine?.source?.file).toBe('<source>');
expect(programIdLine?.source?.line).toBe(14);
expect(programIdLine?.source?.column).toBe(24);
// Verify unknown locations are not associated with source information
const constantLine = result.asm.find(line => line.text.includes('arith.constant'));
expect(constantLine).toBeDefined();
expect(constantLine?.source).toBeNull();
// Verify the structure is preserved
const moduleStartLine = result.asm.find(line => line.text.includes('module {'));
expect(moduleStartLine).toBeDefined();
const funcLine = result.asm.find(line => line.text.includes('tt.func public @add_kernel('));
expect(funcLine).toBeDefined();
expect(funcLine?.text?.includes('%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
expect(funcLine?.text?.includes('%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
expect(funcLine?.text?.includes('%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
expect(funcLine?.text?.includes('%arg3: i32 {tt.divisibility = 16 : i32})')).toBe(true);
const constLine = result.asm.find(line => line.text.includes('arith.constant 1024'));
expect(constLine).toBeDefined();
});
});
});


@@ -0,0 +1,213 @@
// Copyright (c) 2025, Compiler Explorer Authors
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
import {beforeAll, describe, expect, it} from 'vitest';
import {MlirPassDumpParser} from '../lib/parsers/mlir-pass-dump-parser.js';
import * as properties from '../lib/properties.js';
function deepCopy<T>(obj: T): T {
return JSON.parse(JSON.stringify(obj));
}
describe('mlir-pass-dump-parser', () => {
let mlirPassDumpParser: MlirPassDumpParser;
beforeAll(() => {
const fakeProps = new properties.CompilerProps({mlir: {id: 'mlir'}}, properties.fakeProps({}));
const compilerProps = (fakeProps.get as any).bind(fakeProps, 'mlir');
mlirPassDumpParser = new MlirPassDumpParser(compilerProps);
});
const rawMlirDump = [
{
text: "// -----// IR Dump Before Inliner (inline) ('builtin.module' operation) //----- //",
},
{text: 'module {'},
{
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
},
{text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'},
{text: ' tt.return loc(#loc2)'},
{text: ' } loc(#loc)'},
{text: '} loc(#loc)'},
{text: '#loc = loc("/app/example.py":7:0)'},
{text: '#loc1 = loc("/app/example.py":8:24)'},
{text: '#loc2 = loc("/app/example.py":8:4)'},
{text: ''},
{text: ''},
{
text: "// -----// IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel) //----- //",
},
{text: 'module {'},
{
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
},
{text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'},
{text: ' tt.return loc(#loc2)'},
{text: ' } loc(#loc)'},
{text: '} loc(#loc)'},
{text: '#loc = loc("/app/example.py":7:0)'},
{text: '#loc1 = loc("/app/example.py":8:24)'},
{text: '#loc2 = loc("/app/example.py":8:4)'},
{text: ''},
{text: ''},
{
text: "// -----// IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation) //----- //",
},
{text: 'module {'},
{
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
},
{text: ' tt.return loc(#loc1)'},
{text: ' } loc(#loc)'},
{text: '} loc(#loc)'},
{text: '#loc = loc("/app/example.py":7:0)'},
{text: '#loc1 = loc("/app/example.py":8:4)'},
{text: ''},
{text: ''},
];
it('should break down output into pass dumps', () => {
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
expect(passDumps.length).toBe(3);
expect(passDumps[0].header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)");
expect(passDumps[1].header).toBe(
"IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)",
);
expect(passDumps[2].header).toBe(
"IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)",
);
// Check that the first pass dump has the correct lines
expect(passDumps[0].lines.length).toBe(10);
expect(passDumps[0].lines[0].text).toBe('module {');
expect(passDumps[0].lines[1].text).toBe(' tt.func public @add_kernel() attributes {noinline = false} {');
});
it('should break down pass dumps into functions', () => {
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
const splitPassDump = mlirPassDumpParser.breakdownPassDumpsIntoFunctions(passDumps[0]);
expect(splitPassDump.header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)");
expect(Object.keys(splitPassDump.functions)).toContain('add_kernel');
expect(splitPassDump.functions['add_kernel'].length).toBe(10);
});
it('should apply IR filters to remove location information', () => {
const filtered = mlirPassDumpParser.applyIrFilters(deepCopy(rawMlirDump.slice(0, 7)));
// Should filter out location references but keep the lines
expect(filtered.length).toBe(7);
expect(filtered[3].text).toBe(' %0 = tt.get_program_id x : i32');
expect(filtered[4].text).toBe(' tt.return');
expect(filtered[5].text).toBe(' }');
expect(filtered[6].text).toBe('}');
});
it('should break down output into pass dumps by function', () => {
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump));
const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps);
expect(Object.keys(passDumpsByFunction)).toContain('add_kernel');
expect(passDumpsByFunction['add_kernel'].length).toBe(3);
// Check that the function has all three passes
const headers = passDumpsByFunction['add_kernel'].map(dump => dump.header);
expect(headers).toContain("IR Dump Before Inliner (inline) ('builtin.module' operation)");
expect(headers).toContain("IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
expect(headers).toContain(
"IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)",
);
});
it('should detect IR changes between passes', () => {
// Create two different IR dumps to test change detection
const before = [
{text: 'module {'},
{text: ' tt.func public @add_kernel() {'},
{text: ' %0 = tt.get_program_id x : i32'},
{text: ' tt.return'},
{text: ' }'},
{text: '}'},
];
const afterNoChange = deepCopy(before);
expect(mlirPassDumpParser.isIrChanged(before, afterNoChange)).toBe(false);
const afterWithChange = [
{text: 'module {'},
{text: ' tt.func public @add_kernel() {'},
{text: ' tt.return'}, // Line removed
{text: ' }'},
{text: '}'},
];
expect(mlirPassDumpParser.isIrChanged(before, afterWithChange)).toBe(true);
});
it('should match pass dumps and detect changes', () => {
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump));
const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps);
const matchedPasses = mlirPassDumpParser.matchPassDumps(passDumpsByFunction);
expect(Object.keys(matchedPasses)).toContain('add_kernel');
const addKernelPasses = matchedPasses['add_kernel'];
expect(addKernelPasses.length).toBe(2); // We should have 2 passes (comparing the 3 dumps)
// Check the first pass (Inliner to Canonicalizer)
expect(addKernelPasses[0].name).toBe("Inliner (inline) ('builtin.module' operation)");
expect(addKernelPasses[0].irChanged).toBe(false); // No changes between these two dumps
// Check the second pass (Canonicalizer to TritonRewriteTensorPointer)
expect(addKernelPasses[1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
expect(addKernelPasses[1].irChanged).toBe(true); // There are changes between these two dumps
});
it('should process the complete output correctly', () => {
const optPipelineOptions = {
fullModule: false,
filterDebugInfo: false,
filterIRMetadata: false,
noDiscardValueNames: true,
demangle: true,
libraryFunctions: false,
};
const result = mlirPassDumpParser.process(deepCopy(rawMlirDump), {}, optPipelineOptions);
expect(Object.keys(result)).toContain('add_kernel');
expect(result['add_kernel'].length).toBe(2);
// Verify the passes are correctly identified
expect(result['add_kernel'][0].name).toBe("Inliner (inline) ('builtin.module' operation)");
expect(result['add_kernel'][1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
// Verify the IR changes are correctly detected
expect(result['add_kernel'][0].irChanged).toBe(false);
expect(result['add_kernel'][1].irChanged).toBe(true);
});
});


@@ -101,6 +101,7 @@ export type LanguageKey =
| 'swift'
| 'tablegen'
| 'toit'
| 'triton'
| 'typescript'
| 'v'
| 'vala'