mirror of
https://github.com/compiler-explorer/compiler-explorer.git
synced 2025-12-27 10:33:59 -05:00
Add Triton language and compiler (#7919)
Close #5530. Infra: https://github.com/compiler-explorer/infra/pull/1711. Previous work by @siboehm at #5531 ## Summary This pull request introduces support for the [Triton](https://github.com/triton-lang/triton) language, a Python-based DSL for writing highly efficient GPU kernels. - [x] **New Language Support**: I've added comprehensive support for the Triton programming language, allowing users to compile and inspect Triton kernels within Compiler Explorer. (c.f., `lib/compilers/triton.ts`) - [x] **Python Wrapper for Compilation**: A new Python wrapper script (`triton_wrapper.py`) has been introduced to manage Triton compilation, patching its behavior to dump compiled kernels and intermediate representations without requiring actual execution, and consolidating the output for Compiler Explorer. - [x] **Device Assembly View**: Enables viewing of generated device assembly code (e.g., PTX, AMDGCN) and various intermediate representations (MLIR, LLVM IR) produced by the Triton compiler. - [x] **MLIR Parsing**: New parsers (`asm-parser-mlir.ts` and `mlir-pass-dump-parser.ts`) have been added to correctly interpret and display MLIR assembly and optimization pass dumps, including source location information. - [x] **Multi-Version & Multi-Backend Support**: Painstakingly includes all 8 versions (from 2.2.0 to 3.3.1) of Triton that supports Python 3.12. Supports both CUDA and HIP backend for Triton 3. ## Screenshots Source and assembly: <img width="1354" height="789" alt="image" src="https://github.com/user-attachments/assets/c29650ff-2073-40e0-a9e6-ff8377094b5e" /> Device view for MLIR and LLVM IR: <img width="1402" height="670" alt="image" src="https://github.com/user-attachments/assets/43dd5c68-ca78-41b1-9865-e97ffe3ef73c" /> Opt pipeline viewer: <img width="1408" height="668" alt="image" src="https://github.com/user-attachments/assets/429eef8c-aaac-4781-aafa-39ef0ffc7241" /> Diff of TTIR in Triton 3.3.1 vs 2.3.0: <img width="1580" height="726" alt="image" src="https://github.com/user-attachments/assets/a928c893-dd9a-4c3a-a048-14046e56a14c" /> CUDA & HIP: <img width="1596" height="800" alt="image" src="https://github.com/user-attachments/assets/c18800c3-cfad-4e5e-96de-ba92c9f236ea" /> ## Implementation Details (and Notes for Reviewers) - For Device Assembly View, I Implemented `MlirAsmParser` for parsing MLIR assembly. Technically MLIR is not an assembly language, but there is no better choice to make the source line map work w/ device view. - I Implemented `MlirPassDumpParser` for processing MLIR optimization pass dumps. I tried to subclass `LlvmPassDumpParser`, but they turn out to be too different to worth doing it. - `LlvmPassDumpParser` made some assumptions that do not hold true for MLIR passed. Some effort is put to make sure that the passes are properly diff-ed, since some passes can run multiple times and also sometimes pass can be nested (i.e., some number of `before`s followed by some number of `after`s) - A lot of effort is put into `patch_triton` to make sure that the we only compile the kernel without actually running it, and that needs to work across all the versions we support. ## Steps to Run Locally 1. Clone https://github.com/ShawnZhong/compiler-explorer-infra.git 2. Install Triton to `/opt/compiler-explorer/triton`: ```sh $ cd compiler-explorer-infra $ ./bin/ce_install install triton $ ls /opt/compiler-explorer/triton # v2.2.0 v2.3.0 v2.3.1 v3.0.0 v3.1.0 v3.2.0 v3.3.0 v3.3.1 ``` 3. Clone https://github.com/ShawnZhong/compiler-explorer.git and checkout branch `triton` 4. Run Compiler Explorer ```sh make EXTRA_ARGS='--language triton' dev ``` 5. Enjoy --------- Co-authored-by: Matt Godbolt <matt@godbolt.org>
This commit is contained in:
6
.github/labeler.yml
vendored
6
.github/labeler.yml
vendored
@@ -414,6 +414,12 @@
|
||||
- 'etc/config/toit.*.properties'
|
||||
- 'static/modes/toit-mode.ts'
|
||||
|
||||
'lang-triton':
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- 'lib/compilers/triton.ts'
|
||||
- 'etc/config/triton.*.properties'
|
||||
|
||||
'lang-typescript':
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
52
etc/config/triton.amazon.properties
Normal file
52
etc/config/triton.amazon.properties
Normal file
@@ -0,0 +1,52 @@
|
||||
compilers=&triton_nvidia:&triton_amd
|
||||
defaultCompiler=triton_nvidia_331
|
||||
compilerType=triton
|
||||
interpreted=true
|
||||
supportsBinary=false
|
||||
supportsExecute=false
|
||||
isSemVer=true
|
||||
notification=Experimental Triton support on Compiler Explorer. For tutorials, bugs reports, and feature requests, please visit <a href="https://github.com/ShawnZhong/compiler-explorer-triton">here</a>.
|
||||
|
||||
group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230
|
||||
group.triton_nvidia.groupName=Triton (Nvidia)
|
||||
group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32
|
||||
|
||||
compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia)
|
||||
compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
|
||||
|
||||
compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia)
|
||||
compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia)
|
||||
compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia)
|
||||
compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia)
|
||||
compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia)
|
||||
compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3
|
||||
|
||||
compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia)
|
||||
compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3
|
||||
|
||||
group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300
|
||||
group.triton_amd.groupName=Triton (AMD)
|
||||
group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32
|
||||
|
||||
compiler.triton_amd_331.name=Triton 3.3.1 (AMD)
|
||||
compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
|
||||
|
||||
compiler.triton_amd_330.name=Triton 3.3.0 (AMD)
|
||||
compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
|
||||
|
||||
compiler.triton_amd_320.name=Triton 3.2.0 (AMD)
|
||||
compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
|
||||
|
||||
compiler.triton_amd_310.name=Triton 3.1.0 (AMD)
|
||||
compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
|
||||
|
||||
compiler.triton_amd_300.name=Triton 3.0.0 (AMD)
|
||||
compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
|
||||
52
etc/config/triton.defaults.properties
Normal file
52
etc/config/triton.defaults.properties
Normal file
@@ -0,0 +1,52 @@
|
||||
compilers=&triton_nvidia:&triton_amd
|
||||
defaultCompiler=triton_nvidia_331
|
||||
compilerType=triton
|
||||
interpreted=true
|
||||
supportsBinary=false
|
||||
supportsExecute=false
|
||||
isSemVer=true
|
||||
notification=Experimental Triton support on Compiler Explorer. For tutorials, bugs reports, and feature requests, please visit <a href="https://github.com/ShawnZhong/compiler-explorer-triton">here</a>.
|
||||
|
||||
group.triton_nvidia.compilers=triton_nvidia_331:triton_nvidia_330:triton_nvidia_320:triton_nvidia_310:triton_nvidia_300:triton_nvidia_231:triton_nvidia_230
|
||||
group.triton_nvidia.groupName=Triton (Nvidia)
|
||||
group.triton_nvidia.options=--backend cuda --arch 90 --warp_size 32
|
||||
|
||||
compiler.triton_nvidia_331.name=Triton 3.3.1 (Nvidia)
|
||||
compiler.triton_nvidia_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
|
||||
|
||||
compiler.triton_nvidia_330.name=Triton 3.3.0 (Nvidia)
|
||||
compiler.triton_nvidia_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_320.name=Triton 3.2.0 (Nvidia)
|
||||
compiler.triton_nvidia_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_310.name=Triton 3.1.0 (Nvidia)
|
||||
compiler.triton_nvidia_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_300.name=Triton 3.0.0 (Nvidia)
|
||||
compiler.triton_nvidia_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
|
||||
|
||||
compiler.triton_nvidia_231.name=Triton 2.3.1 (Nvidia)
|
||||
compiler.triton_nvidia_231.exe=/opt/compiler-explorer/triton/v2.3.1/bin/python3
|
||||
|
||||
compiler.triton_nvidia_230.name=Triton 2.3.0 (Nvidia)
|
||||
compiler.triton_nvidia_230.exe=/opt/compiler-explorer/triton/v2.3.0/bin/python3
|
||||
|
||||
group.triton_amd.compilers=triton_amd_331:triton_amd_330:triton_amd_320:triton_amd_310:triton_amd_300
|
||||
group.triton_amd.groupName=Triton (AMD)
|
||||
group.triton_amd.options=--backend hip --arch gfx942 --warp_size 32
|
||||
|
||||
compiler.triton_amd_331.name=Triton 3.3.1 (AMD)
|
||||
compiler.triton_amd_331.exe=/opt/compiler-explorer/triton/v3.3.1/bin/python3
|
||||
|
||||
compiler.triton_amd_330.name=Triton 3.3.0 (AMD)
|
||||
compiler.triton_amd_330.exe=/opt/compiler-explorer/triton/v3.3.0/bin/python3
|
||||
|
||||
compiler.triton_amd_320.name=Triton 3.2.0 (AMD)
|
||||
compiler.triton_amd_320.exe=/opt/compiler-explorer/triton/v3.2.0/bin/python3
|
||||
|
||||
compiler.triton_amd_310.name=Triton 3.1.0 (AMD)
|
||||
compiler.triton_amd_310.exe=/opt/compiler-explorer/triton/v3.1.0/bin/python3
|
||||
|
||||
compiler.triton_amd_300.name=Triton 3.0.0 (AMD)
|
||||
compiler.triton_amd_300.exe=/opt/compiler-explorer/triton/v3.0.0/bin/python3
|
||||
254
etc/scripts/triton_wrapper.py
Normal file
254
etc/scripts/triton_wrapper.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import argparse
|
||||
import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
||||
|
||||
class MockCacheManager(triton.runtime.cache.CacheManager):
|
||||
"""
|
||||
A mock cache manager that dumps the intermediate files to a given output path.
|
||||
|
||||
There are various ways to dump the intermediate files:
|
||||
1. The most obvious way is to use the `TRITON_KERNEL_DUMP` & ``TRITON_DUMP_DIR`
|
||||
environment variables. e.g.,
|
||||
os.environ["TRITON_KERNEL_DUMP"] = "1"
|
||||
os.environ["TRITON_DUMP_DIR"] = str(output_dir)
|
||||
However, `TRITON_DUMP_DIR` is introduced in Triton v3.2.0 at
|
||||
https://github.com/triton-lang/triton/commit/ca469d7b6b6def316b5f5ee6ad2bd19dcb840bd8,
|
||||
and thus not available in older versions.
|
||||
|
||||
2. The second attempt is to patch the `default_cache_dir` function. e.g.,
|
||||
triton.runtime.cache.default_cache_dir = MagicMock(return_value=output_dir)
|
||||
This is a bit hacky, and less flexible in terms of controlling the file output.
|
||||
(In fact, Triton dumps the compiled kernels to a folder with a random name.)
|
||||
|
||||
3. Another option is to use various hooks in Triton. e.g.,
|
||||
triton.knobs.runtime.{jit_post_compile_hook,launch_enter_hook}
|
||||
JITFunction.{compiled_hook,cache_hook}
|
||||
This approach is taken by TritonParse(https://github.com/pytorch-labs/tritonparse),
|
||||
but it does not support older versions of Triton prior to the following commits:
|
||||
https://github.com/triton-lang/triton/commit/0e9267202532ed1709dcc12c636220cf239dc377,
|
||||
https://github.com/triton-lang/triton/commit/850525276426fb9814399a8e0ee8fdf744229b02.
|
||||
|
||||
4. The current apporach is to mock a `CacheManager` class. This is the most flexible
|
||||
approach, and works for all versions of Triton.
|
||||
"""
|
||||
|
||||
output_file: Path
|
||||
|
||||
def __init__(self, key, override=False, dump=False):
|
||||
self.dump = dump
|
||||
# filename -> data
|
||||
self.files = {}
|
||||
# filename -> group dict
|
||||
self.groups = {}
|
||||
# current stage for a given kernel
|
||||
self.stage = 0
|
||||
|
||||
def get_file(self, filename) -> Optional[str]:
|
||||
return self.files.get(filename, None)
|
||||
|
||||
def put(self, data, filename, binary=True) -> str:
|
||||
name = Path(filename).stem
|
||||
suffix = Path(filename).suffix
|
||||
binary = isinstance(data, bytes)
|
||||
if not binary:
|
||||
data = str(data)
|
||||
|
||||
# Write the final file to the output file, so that we can view it in the default assembly view.
|
||||
if suffix in (".ptx", ".amdgcn"):
|
||||
with open(MockCacheManager.output_file, "a") as f:
|
||||
f.write(data)
|
||||
f.write("\n\n")
|
||||
|
||||
# Write intermediate files to the output file, so that we can see them in the Device View.
|
||||
self.stage += 1
|
||||
if suffix == ".json":
|
||||
path = MockCacheManager.output_file.parent / filename
|
||||
with open(path, "w") as fout:
|
||||
json.dump(json.loads(data), fout, indent=2)
|
||||
else:
|
||||
path = (
|
||||
MockCacheManager.output_file.parent
|
||||
/ f"{name} [stage {self.stage}]{suffix}"
|
||||
)
|
||||
if not binary:
|
||||
with open(path, "w") as fout:
|
||||
fout.write(data)
|
||||
elif suffix == ".cubin":
|
||||
try:
|
||||
# The try-catch is needed because `disasm` was broken in Triton v3.0.0 and v3.1.0. See
|
||||
# https://github.com/triton-lang/triton/commit/f424f656b3528c47d8c48126cdccafca29e536ae
|
||||
from triton.tools import disasm
|
||||
|
||||
with open(path.with_suffix(".sass"), "w") as fout:
|
||||
fout.write(disasm.get_sass(data))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Write the file to the "cache"
|
||||
self.files[filename] = data
|
||||
return filename
|
||||
|
||||
def get_group(self, filename: str) -> Optional[Dict[str, str]]:
|
||||
self.groups.get(filename, None)
|
||||
|
||||
def put_group(self, filename: str, group: Dict[str, str]):
|
||||
self.groups[filename] = group
|
||||
|
||||
|
||||
def setup_triton(
|
||||
output_file: Path,
|
||||
opt_pipeline_file: Path,
|
||||
backend: str,
|
||||
arch: Union[int, str],
|
||||
warp_size: int,
|
||||
):
|
||||
"""
|
||||
Patch Triton to dump the compiled kernels to output dir without actually running them.
|
||||
|
||||
This is needed because
|
||||
1. Triton does not easily support such use case. There exists an AOT compiler at
|
||||
https://github.com/triton-lang/triton/blob/main/python/triton/tools/compile.py,
|
||||
but it requires a bunch of boilerplate code and also requires additional user
|
||||
input to specify the kernel name, signature, etc.
|
||||
2. Even if Triton adds such support, older versions of Triton (e.g., v2.3.x) still
|
||||
requirs such patching to work.
|
||||
|
||||
This function is a collection of hacks. It has been tested to work with Triton
|
||||
2.3.0, 2.3.1, 3.0.0, 3.1.0, 3.2.0, 3.3.0, 3.3.1.
|
||||
"""
|
||||
|
||||
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
|
||||
if opt_pipeline_file:
|
||||
os.environ["MLIR_ENABLE_DUMP"] = "1"
|
||||
# Supported in Triton v3.3.0 and later since
|
||||
# https://github.com/triton-lang/triton/commit/3d7d9e33e7e4cba17dc366d207af2c657bd4fbd1
|
||||
os.environ["MLIR_DUMP_PATH"] = str(opt_pipeline_file)
|
||||
else:
|
||||
# Disable dumping other files w/ opt_pipeline_file since they race with each other
|
||||
os.environ["TRITON_CACHE_MANAGER"] = "__main__:MockCacheManager"
|
||||
MockCacheManager.output_file = output_file
|
||||
|
||||
# Usually, Triton compiles and run a kernel when we call `kernel[grid](args)`.
|
||||
# However, we want to dump the compiled kernel without actually running it.
|
||||
# The class `CompiledKernel` represents a handle to a compiled kernel,
|
||||
# ready to be launched. We patch it to be a no-op.
|
||||
triton.compiler.compiler.CompiledKernel = MagicMock()
|
||||
|
||||
# We mock a GPU driver to avoid the need to initialize CUDA/ROCm.
|
||||
# The driver is only used in runtime instead of compile time,
|
||||
# so it's safe to do this.
|
||||
def get_current_target():
|
||||
try:
|
||||
from triton.compiler.compiler import GPUTarget
|
||||
|
||||
return GPUTarget(backend=backend, arch=arch, warp_size=warp_size)
|
||||
except ImportError:
|
||||
# For Triton v2.3.x, we don't have GPUTarget
|
||||
return (backend, arch)
|
||||
|
||||
mockGPUDriver = MagicMock(
|
||||
get_current_target=get_current_target,
|
||||
get_benchmarker=lambda: MagicMock(return_value=[0.0]),
|
||||
)
|
||||
|
||||
# Set the active driver to the mocked one.
|
||||
# `DriverConfig` and `triton.runtime.driver.set_active` is introduced in Triton v3.0.0 at
|
||||
# https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4
|
||||
# For older versions of Triton, we directly assign to the `_obj` field of `LazyProxy`.
|
||||
try:
|
||||
from triton.runtime.driver import DriverConfig
|
||||
|
||||
triton.runtime.driver.set_active(mockGPUDriver)
|
||||
except ImportError:
|
||||
triton.runtime.driver._obj = mockGPUDriver
|
||||
|
||||
# For Triton v2.3.x, there are some driver code that goes into
|
||||
# the generic code path, so we need to patch it as well.
|
||||
try:
|
||||
from triton.compiler.backends.cuda import CUDABackend
|
||||
|
||||
CUDABackend.make_launcher_stub = MagicMock()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def main(
|
||||
input_file: Path,
|
||||
output_file: Path,
|
||||
opt_pipeline_file: Path,
|
||||
backend: str,
|
||||
arch: Union[int, str],
|
||||
warp_size: int,
|
||||
):
|
||||
setup_triton(output_file, opt_pipeline_file, backend, arch, warp_size)
|
||||
|
||||
# Run the script by importing it as a module
|
||||
spec = importlib.util.spec_from_file_location("example", input_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
with FakeTensorMode():
|
||||
# Use FakeTensor (developed during Dynamo) to avoid actually creating tensors
|
||||
# See https://docs.pytorch.org/docs/stable/torch.compiler_fake_tensor.html
|
||||
# Also set the data_ptr to 0 to avoid PyTorch warning and make alignment check happy
|
||||
torch._subclasses.FakeTensor.data_ptr = MagicMock(return_value=0)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Triton wrapper")
|
||||
parser.add_argument(
|
||||
"input_file",
|
||||
type=Path,
|
||||
help="Path to the input Python file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_file",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the output file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opt_pipeline_file",
|
||||
type=Path,
|
||||
help="Path to the output opt pipeline file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda", "hip"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch",
|
||||
type=str,
|
||||
default=None, # Default value set later based on backend
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warp_size",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set some sane defaults for the arch
|
||||
if args.arch is None:
|
||||
if args.backend == "cuda":
|
||||
args.arch = 90
|
||||
elif args.backend == "hip":
|
||||
args.arch = "gfx942"
|
||||
|
||||
# Triton expects the arch to be an int for CUDA and a string for HIP
|
||||
if args.backend == "cuda":
|
||||
args.arch = int(args.arch)
|
||||
|
||||
main(**vars(args))
|
||||
28
examples/triton/default.py
Normal file
28
examples/triton/default.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
x = torch.rand(1024)
|
||||
y = torch.rand(1024)
|
||||
output = torch.empty(1024)
|
||||
n_elements = output.numel()
|
||||
add_kernel[(1,)](x, y, output, n_elements, BLOCK_SIZE=1024)
|
||||
@@ -154,6 +154,7 @@ export {TenDRACompiler} from './tendra.js';
|
||||
export {TIC2000} from './tic2000.js';
|
||||
export {TinyCCompiler} from './tinyc.js';
|
||||
export {ToitCompiler} from './toit.js';
|
||||
export {TritonCompiler} from './triton.js';
|
||||
export {TurboCCompiler} from './turboc.js';
|
||||
export {TypeScriptNativeCompiler} from './typescript-native.js';
|
||||
export {VCompiler} from './v.js';
|
||||
|
||||
201
lib/compilers/triton.ts
Normal file
201
lib/compilers/triton.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright (c) 2025, Compiler Explorer Authors
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
// POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import * as fs from 'node:fs/promises';
|
||||
import Path from 'node:path';
|
||||
import type {CompilationInfo, CompilationResult} from '../../types/compilation/compilation.interfaces.js';
|
||||
import type {
|
||||
OptPipelineBackendOptions,
|
||||
OptPipelineOutput,
|
||||
} from '../../types/compilation/opt-pipeline-output.interfaces.js';
|
||||
import type {PreliminaryCompilerInfo} from '../../types/compiler.interfaces.js';
|
||||
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
|
||||
import {BaseCompiler} from '../base-compiler.js';
|
||||
import {CompilationEnvironment} from '../compilation-env.js';
|
||||
import type {IAsmParser} from '../parsers/asm-parser.interfaces.js';
|
||||
import {AmdgpuAsmParser} from '../parsers/asm-parser-amdgpu.js';
|
||||
import {MlirAsmParser} from '../parsers/asm-parser-mlir.js';
|
||||
import {PTXAsmParser} from '../parsers/asm-parser-ptx.js';
|
||||
import {SassAsmParser} from '../parsers/asm-parser-sass.js';
|
||||
import {MlirPassDumpParser} from '../parsers/mlir-pass-dump-parser.js';
|
||||
import {parseOutput, resolvePathFromAppRoot} from '../utils.js';
|
||||
import {BaseParser} from './argument-parsers.js';
|
||||
|
||||
export class TritonCompiler extends BaseCompiler {
|
||||
private compilerWrapperPath: string;
|
||||
|
||||
static get key() {
|
||||
return 'triton';
|
||||
}
|
||||
|
||||
parserMap: Record<string, IAsmParser>;
|
||||
mlirPassDumpParser: MlirPassDumpParser;
|
||||
|
||||
constructor(compilerInfo: PreliminaryCompilerInfo, env: CompilationEnvironment) {
|
||||
super(compilerInfo, env);
|
||||
|
||||
this.compilerWrapperPath =
|
||||
this.compilerProps('compilerWrapper', '') || resolvePathFromAppRoot('etc', 'scripts', 'triton_wrapper.py');
|
||||
|
||||
// Enable the Opt Pipeline view
|
||||
this.compiler.optPipeline = {};
|
||||
// Used to parse the output of the opt pipeline
|
||||
this.mlirPassDumpParser = new MlirPassDumpParser(this.compilerProps);
|
||||
|
||||
// Enable the Device Viewer
|
||||
this.compiler.supportsDeviceAsmView = true;
|
||||
// Define parsers for the different output files displayed in the Device Viewer
|
||||
const sassAsmParser = new SassAsmParser(this.compilerProps);
|
||||
const ptxAsmParser = new PTXAsmParser(this.compilerProps);
|
||||
const amdgpuAsmParser = new AmdgpuAsmParser();
|
||||
const mlirAsmParser = new MlirAsmParser();
|
||||
this.parserMap = {
|
||||
'.ttir': mlirAsmParser,
|
||||
'.ttgir': mlirAsmParser,
|
||||
'.ptx': ptxAsmParser,
|
||||
'.sass': sassAsmParser,
|
||||
'.source': mlirAsmParser,
|
||||
'.amdgcn': amdgpuAsmParser,
|
||||
'.llir': mlirAsmParser,
|
||||
'.json': sassAsmParser,
|
||||
};
|
||||
|
||||
if (compilerInfo.group == 'triton_amd') {
|
||||
this.asm = amdgpuAsmParser;
|
||||
} else if (compilerInfo.group == 'triton_nvidia') {
|
||||
this.asm = ptxAsmParser;
|
||||
}
|
||||
}
|
||||
|
||||
override optionsForFilter(filters: ParseFiltersAndOutputOptions, outputFilename: string): string[] {
|
||||
// See etc/scripts/triton_wrapper.py for the options
|
||||
return ['-I', this.compilerWrapperPath, '--output_file', outputFilename];
|
||||
}
|
||||
|
||||
override getArgumentParserClass() {
|
||||
return BaseParser;
|
||||
}
|
||||
|
||||
override async extractDeviceCode(
|
||||
result: CompilationResult,
|
||||
filters: ParseFiltersAndOutputOptions,
|
||||
compilationInfo: CompilationInfo,
|
||||
) {
|
||||
const devices = {...result.devices};
|
||||
|
||||
const {dirPath} = result;
|
||||
if (!dirPath) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Extract the device code from the output directory
|
||||
const files = await fs.readdir(dirPath);
|
||||
await Promise.all(
|
||||
files.map(async filename => {
|
||||
const ext = Path.extname(filename);
|
||||
const parser = this.parserMap[ext];
|
||||
if (!parser) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Read the file
|
||||
const data = await fs.readFile(Path.join(dirPath, filename), 'utf8');
|
||||
|
||||
// Parse the assembly with line numbers
|
||||
let device;
|
||||
if (ext === '.llir') {
|
||||
device = await this.llvmIr.process(data, {
|
||||
filterDebugInfo: false,
|
||||
filterIRMetadata: false,
|
||||
filterAttributes: false,
|
||||
filterComments: false,
|
||||
noDiscardValueNames: false,
|
||||
demangle: false,
|
||||
});
|
||||
} else {
|
||||
device = await parser.process(data, filters);
|
||||
}
|
||||
|
||||
Object.assign(devices, {[filename]: device});
|
||||
}),
|
||||
);
|
||||
result.devices = devices;
|
||||
return result;
|
||||
}
|
||||
|
||||
override async generateOptPipeline(
|
||||
inputFilename: string,
|
||||
options: string[],
|
||||
filters: ParseFiltersAndOutputOptions,
|
||||
optPipelineOptions: OptPipelineBackendOptions,
|
||||
): Promise<OptPipelineOutput | undefined> {
|
||||
// Call the script to generate the opt pipeline
|
||||
const execOptions = this.getDefaultExecOptions();
|
||||
const outputFilename = Path.join(Path.dirname(inputFilename), 'opt_pipeline.txt');
|
||||
const optOptions = [...options, '--opt_pipeline_file', outputFilename];
|
||||
|
||||
const compileStart = performance.now();
|
||||
await this.runCompiler(this.compiler.exe, optOptions, inputFilename, execOptions);
|
||||
const compileEnd = performance.now();
|
||||
|
||||
// Read the output file and parse it
|
||||
try {
|
||||
const rawText = await fs.readFile(outputFilename, {encoding: 'utf8'});
|
||||
const lines = parseOutput(rawText);
|
||||
|
||||
const parseStart = performance.now();
|
||||
const llvmOptPipeline = await this.mlirPassDumpParser.process(lines, filters, optPipelineOptions);
|
||||
const parseEnd = performance.now();
|
||||
|
||||
return {
|
||||
results: llvmOptPipeline,
|
||||
compileTime: compileEnd - compileStart,
|
||||
parseTime: parseEnd - parseStart,
|
||||
};
|
||||
} catch (e: any) {
|
||||
return {
|
||||
error: e.toString(),
|
||||
results: {},
|
||||
compileTime: compileEnd - compileStart,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
override getDefaultFilters() {
|
||||
return {
|
||||
intel: false,
|
||||
commentOnly: false,
|
||||
directives: false,
|
||||
labels: false,
|
||||
optOutput: true,
|
||||
binary: false,
|
||||
execute: false,
|
||||
demangle: false,
|
||||
libraryCode: false,
|
||||
trim: false,
|
||||
binaryObject: false,
|
||||
debugCalls: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -894,6 +894,17 @@ const definitions: Record<LanguageKey, LanguageDefinition> = {
|
||||
previewFilter: null,
|
||||
monacoDisassembly: null,
|
||||
},
|
||||
triton: {
|
||||
name: 'Triton',
|
||||
monaco: 'python',
|
||||
extensions: ['.py'],
|
||||
alias: [],
|
||||
logoFilename: 'triton.png',
|
||||
logoFilenameDark: null,
|
||||
formatter: null,
|
||||
previewFilter: null,
|
||||
monacoDisassembly: null,
|
||||
},
|
||||
typescript: {
|
||||
name: 'TypeScript Native',
|
||||
monaco: 'typescript',
|
||||
|
||||
146
lib/parsers/asm-parser-mlir.ts
Normal file
146
lib/parsers/asm-parser-mlir.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2025, Compiler Explorer Authors
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
// POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import {AsmResultSource, ParsedAsmResult, ParsedAsmResultLine} from '../../types/asmresult/asmresult.interfaces.js';
|
||||
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
|
||||
import * as utils from '../utils.js';
|
||||
|
||||
import {AsmParser} from './asm-parser.js';
|
||||
|
||||
export class MlirAsmParser extends AsmParser {
|
||||
protected locDefRegex: RegExp;
|
||||
protected locDefUnknownRegex: RegExp;
|
||||
protected locRefRegex: RegExp;
|
||||
protected locRefRegexReplace: RegExp;
|
||||
protected inlineLocRegex: RegExp;
|
||||
protected inlineLocRegexReplace: RegExp;
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
|
||||
// Match location definitions like #loc1 = loc("/path/to/file":line:column)
|
||||
this.locDefRegex = /^#(\w+)\s*=\s*loc\("([^"]+)":(\d+):(\d+)\)/;
|
||||
|
||||
// Match location definitions like #loc1 = loc(unknown)
|
||||
this.locDefUnknownRegex = /^#(\w+)\s*=\s*loc\(unknown\)/;
|
||||
|
||||
// Match location references like loc(#loc1)
|
||||
this.locRefRegex = /\s*loc\(#(\w+)\)/;
|
||||
this.locRefRegexReplace = new RegExp(this.locRefRegex.source, 'g');
|
||||
|
||||
// Match inline locations like loc("/path/to/file":line:column)
|
||||
this.inlineLocRegex = /\s*loc\("([^"]+)":(\d+):(\d+)\)/;
|
||||
this.inlineLocRegexReplace = new RegExp(this.inlineLocRegex.source, 'g');
|
||||
}
|
||||
|
||||
override processAsm(asmResult: string, filters: ParseFiltersAndOutputOptions): ParsedAsmResult {
|
||||
const startTime = process.hrtime.bigint();
|
||||
|
||||
const asm: ParsedAsmResultLine[] = [];
|
||||
const asmLines = utils.splitLines(asmResult);
|
||||
const startingLineCount = asmLines.length;
|
||||
|
||||
// First pass: extract all location definitions
|
||||
const locationMap = new Map<string, AsmResultSource>();
|
||||
for (const line of asmLines) {
|
||||
const locMatch = line.match(this.locDefRegex);
|
||||
if (locMatch) {
|
||||
const locId = locMatch[1];
|
||||
const file = locMatch[2];
|
||||
const lineNum = Number.parseInt(locMatch[3], 10);
|
||||
const column = Number.parseInt(locMatch[4], 10);
|
||||
|
||||
locationMap.set(locId, {
|
||||
file: utils.maskRootdir(file),
|
||||
line: lineNum,
|
||||
column: column,
|
||||
mainsource: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: process each line and associate with source information
|
||||
for (const line of asmLines) {
|
||||
// Skip location definition lines
|
||||
if (this.locDefRegex.test(line) || this.locDefUnknownRegex.test(line)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply filters if needed
|
||||
let processedLine = line;
|
||||
if (filters.trim) {
|
||||
processedLine = processedLine.trim();
|
||||
}
|
||||
|
||||
if (filters.commentOnly && processedLine.trim().startsWith('//')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find source information from location references
|
||||
let source: AsmResultSource | null = null;
|
||||
|
||||
// Check for location references like loc(#loc1)
|
||||
const locRefMatch = line.match(this.locRefRegex);
|
||||
if (locRefMatch) {
|
||||
const locId = locRefMatch[1];
|
||||
source = locationMap.get(locId) || null;
|
||||
// Remove location reference from the displayed text
|
||||
processedLine = processedLine.replace(this.locRefRegexReplace, '');
|
||||
} else {
|
||||
// Check for inline locations like loc("/path/to/file":line:column)
|
||||
const inlineLocMatch = line.match(this.inlineLocRegex);
|
||||
if (inlineLocMatch) {
|
||||
const file = inlineLocMatch[1];
|
||||
const lineNum = Number.parseInt(inlineLocMatch[2], 10);
|
||||
const column = Number.parseInt(inlineLocMatch[3], 10);
|
||||
|
||||
source = {
|
||||
file: utils.maskRootdir(file),
|
||||
line: lineNum,
|
||||
column: column,
|
||||
mainsource: true,
|
||||
};
|
||||
}
|
||||
// Remove inline location from the displayed text
|
||||
processedLine = processedLine.replace(this.inlineLocRegexReplace, '');
|
||||
}
|
||||
|
||||
// Add the line to the result
|
||||
asm.push({
|
||||
text: processedLine,
|
||||
source: source,
|
||||
labels: [],
|
||||
});
|
||||
}
|
||||
|
||||
const endTime = process.hrtime.bigint();
|
||||
return {
|
||||
asm: asm,
|
||||
labelDefinitions: {},
|
||||
languageId: 'mlir',
|
||||
parsingTime: utils.deltaTimeNanoToMili(startTime, endTime),
|
||||
filteredCount: startingLineCount - asm.length,
|
||||
};
|
||||
}
|
||||
}
|
||||
374
lib/parsers/mlir-pass-dump-parser.ts
Normal file
374
lib/parsers/mlir-pass-dump-parser.ts
Normal file
@@ -0,0 +1,374 @@
|
||||
// Copyright (c) 2025, Compiler Explorer Authors
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
// POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import {
|
||||
OptPipelineBackendOptions,
|
||||
OptPipelineResults,
|
||||
Pass,
|
||||
} from '../../types/compilation/opt-pipeline-output.interfaces.js';
|
||||
import {ParseFiltersAndOutputOptions} from '../../types/features/filters.interfaces.js';
|
||||
import {ResultLine} from '../../types/resultline/resultline.interfaces.js';
|
||||
import {assert} from '../assert.js';
|
||||
import {PropertyGetter} from '../properties.interfaces.js';
|
||||
|
||||
// Helper function to extract pass name from header
|
||||
function extractPassName(header: string): string {
|
||||
if (header.startsWith('IR Dump Before ')) {
|
||||
return header.slice('IR Dump Before '.length);
|
||||
}
|
||||
|
||||
if (header.startsWith('IR Dump After ')) {
|
||||
let passName = header.slice('IR Dump After '.length);
|
||||
// Handle invalidated passes
|
||||
if (passName.endsWith(' (invalidated)')) {
|
||||
passName = passName.slice(0, passName.length - ' (invalidated)'.length);
|
||||
}
|
||||
return passName;
|
||||
}
|
||||
assert(false, 'Unexpected pass header', header);
|
||||
}
|
||||
|
||||
// Ir Dump for a pass with raw lines
|
||||
type PassDump = {
|
||||
header: string;
|
||||
affectedFunction: string | undefined;
|
||||
lines: ResultLine[];
|
||||
};
|
||||
|
||||
// Ir Dump for a pass with raw lines broken into affected functions
|
||||
type SplitPassDump = {
|
||||
header: string;
|
||||
functions: Record<string, ResultLine[]>;
|
||||
};
|
||||
|
||||
export class MlirPassDumpParser {
|
||||
locationDefine: RegExp;
|
||||
locationReference: RegExp;
|
||||
irDumpHeader: RegExp;
|
||||
functionDefine: RegExp;
|
||||
moduleDefine: RegExp;
|
||||
functionEnd: RegExp;
|
||||
moduleEnd: RegExp;
|
||||
|
||||
constructor(compilerProps: PropertyGetter) {
|
||||
// Location definitions: #loc0 = loc(...)
|
||||
this.locationDefine = /^#loc\d* = loc\(.+\)$/;
|
||||
|
||||
// Location references: loc(#loc0), loc("/app/example.py":19:0), loc(unknown)
|
||||
this.locationReference = /\s*loc\([^)]*\)/g;
|
||||
|
||||
// MLIR dump headers look like "// -----// IR Dump Before/After XYZ (xyz) ('operation-type' operation: @function_name) //----- //"
|
||||
this.irDumpHeader = /^\/\/ -----\/\/ (IR Dump (?:Before|After) .+) \/\/----- \/\/$/;
|
||||
|
||||
// MLIR function definitions look like "func.func @function_name(...) {"
|
||||
// or "tt.func public @function_name(...) {"
|
||||
this.functionDefine = /^\s*(\w+\.func\s+(?:\w+\s+)?@(\w+).*\{)$/;
|
||||
|
||||
// MLIR module definitions look like "module {"
|
||||
this.moduleDefine = /^\s*(module\s*\{)$/;
|
||||
|
||||
// Functions end with a closing brace
|
||||
this.functionEnd = /^\s*\}\s*$/;
|
||||
|
||||
// Modules end with a closing brace
|
||||
this.moduleEnd = /^\s*\}\s*$/;
|
||||
}
|
||||
|
||||
breakdownOutputIntoPassDumps(ir: ResultLine[]) {
|
||||
// break down output by "// -----// IR Dump Before/After XYZ //----- //" markers
|
||||
const raw_passes: PassDump[] = [];
|
||||
let pass: PassDump | null = null;
|
||||
let lastWasBlank = false; // skip duplicate blank lines
|
||||
|
||||
for (const line of ir) {
|
||||
const irMatch = line.text.match(this.irDumpHeader);
|
||||
|
||||
if (irMatch) {
|
||||
if (pass !== null) {
|
||||
raw_passes.push(pass);
|
||||
}
|
||||
|
||||
const headerText = irMatch[1];
|
||||
pass = {
|
||||
header: headerText,
|
||||
affectedFunction: undefined,
|
||||
lines: [],
|
||||
};
|
||||
lastWasBlank = true; // skip leading newlines after the header
|
||||
} else {
|
||||
if (pass === null) continue;
|
||||
|
||||
if (line.text.trim() === '') {
|
||||
if (!lastWasBlank) {
|
||||
pass.lines.push(line);
|
||||
}
|
||||
lastWasBlank = true;
|
||||
} else {
|
||||
pass.lines.push(line);
|
||||
lastWasBlank = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (pass !== null) {
|
||||
raw_passes.push(pass);
|
||||
}
|
||||
|
||||
return raw_passes;
|
||||
}
|
||||
|
||||
breakdownPassDumpsIntoFunctions(dump: PassDump) {
|
||||
// Simplified version based on the assumption that:
|
||||
// 1. Functions always live inside a single module
|
||||
// 2. A single module always has a single function in it
|
||||
// 3. We use the name of the function to show the entire module
|
||||
const pass: SplitPassDump = {
|
||||
header: dump.header,
|
||||
functions: {},
|
||||
};
|
||||
|
||||
// Find the function name inside the module
|
||||
let functionName: string | null = null;
|
||||
for (const line of dump.lines) {
|
||||
const funcMatch = line.text.match(this.functionDefine);
|
||||
if (funcMatch) {
|
||||
functionName = funcMatch[2];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If we found a function name, use it; otherwise use "module"
|
||||
const name = functionName || 'module';
|
||||
pass.functions[name] = dump.lines;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
breakdownIntoPassDumpsByFunction(passDumps: SplitPassDump[]) {
|
||||
// Currently we have an array of passes with a map of functions altered in each pass
|
||||
// We want to transpose to a map of functions with an array of passes on the function
|
||||
const passDumpsByFunction: Record<string, PassDump[]> = {};
|
||||
|
||||
for (const pass of passDumps) {
|
||||
const {header, functions} = pass;
|
||||
|
||||
for (const [function_name, lines] of Object.entries(functions)) {
|
||||
if (!(function_name in passDumpsByFunction)) {
|
||||
passDumpsByFunction[function_name] = [];
|
||||
}
|
||||
|
||||
passDumpsByFunction[function_name].push({
|
||||
header,
|
||||
affectedFunction: undefined,
|
||||
lines,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return passDumpsByFunction;
|
||||
}
|
||||
|
||||
associateFullDumpsWithFunctions(passDumps: PassDump[]) {
|
||||
// Currently we have an array of passes that'll have target annotations
|
||||
const passDumpsByFunction: Record<string, PassDump[]> = {};
|
||||
|
||||
// First figure out what all the functions are
|
||||
for (const pass of passDumps) {
|
||||
if (pass.affectedFunction) {
|
||||
passDumpsByFunction[pass.affectedFunction] = [];
|
||||
}
|
||||
}
|
||||
|
||||
// Add a special entry for the full module
|
||||
passDumpsByFunction['<Full Module>'] = [];
|
||||
|
||||
for (const pass of passDumps) {
|
||||
const {header, affectedFunction, lines} = pass;
|
||||
|
||||
if (affectedFunction) {
|
||||
// This pass affects a specific function
|
||||
assert(affectedFunction in passDumpsByFunction);
|
||||
|
||||
// Add to both the specific function and the full module view
|
||||
[passDumpsByFunction[affectedFunction], passDumpsByFunction['<Full Module>']].map(entry =>
|
||||
entry.push({
|
||||
header: `${header} (${affectedFunction})`,
|
||||
affectedFunction,
|
||||
lines,
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
// This pass applies to everything
|
||||
for (const [_, entry] of Object.entries(passDumpsByFunction)) {
|
||||
entry.push({
|
||||
header,
|
||||
affectedFunction: undefined,
|
||||
lines,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return passDumpsByFunction;
|
||||
}
|
||||
|
||||
isIrChanged(before: ResultLine[], after: ResultLine[]) {
|
||||
if (before.length !== after.length) {
|
||||
return true;
|
||||
}
|
||||
for (let i = 0; i < before.length; i++) {
|
||||
if (before[i].text !== after[i].text) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
matchPassDumps(passDumpsByFunction: Record<string, PassDump[]>) {
|
||||
// We have all the passes for each function, now we will go through each function and match the before/after
|
||||
// dumps, handling the case where the same pass might occur multiple times
|
||||
const final_output: OptPipelineResults = {};
|
||||
|
||||
for (const [function_name, passDumps] of Object.entries(passDumpsByFunction)) {
|
||||
const passes: Pass[] = [];
|
||||
|
||||
// Use a stack of "Before" passes
|
||||
const beforePasses: PassDump[] = [];
|
||||
|
||||
// Collect all "Before" passes in order
|
||||
for (const dump of passDumps) {
|
||||
if (dump.header.startsWith('IR Dump Before ')) {
|
||||
beforePasses.push(dump);
|
||||
}
|
||||
}
|
||||
|
||||
// Process "After" passes and match with "Before" passes
|
||||
for (const dump of passDumps) {
|
||||
if (dump.header.startsWith('IR Dump After ')) {
|
||||
const afterPassName = extractPassName(dump.header);
|
||||
const pass: Pass = {
|
||||
name: afterPassName,
|
||||
machine: false,
|
||||
after: dump.lines,
|
||||
before: [],
|
||||
irChanged: true,
|
||||
};
|
||||
|
||||
// Find matching "Before" pass by name
|
||||
for (let i = 0; i < beforePasses.length; i++) {
|
||||
const beforePassName = extractPassName(beforePasses[i].header);
|
||||
if (beforePassName === afterPassName) {
|
||||
// Found a match, use it and remove from the stack
|
||||
pass.before = beforePasses[i].lines;
|
||||
|
||||
// Check for equality
|
||||
pass.irChanged = this.isIrChanged(pass.before, pass.after);
|
||||
|
||||
// Remove the matched "Before" pass
|
||||
beforePasses.splice(i, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
passes.push(pass);
|
||||
}
|
||||
}
|
||||
|
||||
// If we only have before passes (no after passes), diff between consecutive before passes
|
||||
// This happened in Triton since it sets enableIRPrinting(printAfterOnlyOnFailure=false)
|
||||
if (passes.length === 0) {
|
||||
for (let i = 0; i < beforePasses.length - 1; i++) {
|
||||
const isLast = i === beforePasses.length - 1;
|
||||
const passName = extractPassName(beforePasses[i].header);
|
||||
const before = beforePasses[i].lines;
|
||||
const after = isLast ? beforePasses[i].lines : beforePasses[i + 1].lines;
|
||||
const irChanged = isLast ? false : this.isIrChanged(before, after);
|
||||
const pass: Pass = {
|
||||
name: passName,
|
||||
machine: false,
|
||||
before: before,
|
||||
after: after,
|
||||
irChanged: irChanged,
|
||||
};
|
||||
passes.push(pass);
|
||||
}
|
||||
} else {
|
||||
// Handle any remaining "Before" passes that don't have corresponding "After" passes
|
||||
for (const beforeDump of beforePasses) {
|
||||
const passName = extractPassName(beforeDump.header);
|
||||
const pass: Pass = {
|
||||
name: passName,
|
||||
machine: false,
|
||||
before: beforeDump.lines,
|
||||
after: [],
|
||||
irChanged: true, // Assume changed since there's no "After" to compare with
|
||||
};
|
||||
passes.push(pass);
|
||||
}
|
||||
}
|
||||
|
||||
final_output[function_name] = passes;
|
||||
}
|
||||
|
||||
return final_output;
|
||||
}
|
||||
|
||||
breakdownOutput(ir: ResultLine[], optPipelineOptions: OptPipelineBackendOptions) {
|
||||
// break down output by "// -----// IR Dump Before/After XYZ //----- //" markers
|
||||
const raw_passes = this.breakdownOutputIntoPassDumps(ir);
|
||||
|
||||
if (optPipelineOptions.fullModule) {
|
||||
const passDumpsByFunction = this.associateFullDumpsWithFunctions(raw_passes);
|
||||
// Match before / after pass dumps and we're done
|
||||
return this.matchPassDumps(passDumpsByFunction);
|
||||
}
|
||||
|
||||
// Further break down by functions in each dump
|
||||
const passDumps = raw_passes.map(this.breakdownPassDumpsIntoFunctions.bind(this));
|
||||
|
||||
// Transform array of passes containing multiple functions into a map from functions to arrays of passes on
|
||||
// those functions
|
||||
const passDumpsByFunction = this.breakdownIntoPassDumpsByFunction(passDumps);
|
||||
|
||||
// Match before / after pass dumps and we're done
|
||||
return this.matchPassDumps(passDumpsByFunction);
|
||||
}
|
||||
|
||||
applyIrFilters(ir: ResultLine[]) {
|
||||
return ir
|
||||
.filter(line => line.text.match(this.locationDefine) === null)
|
||||
.map(line => ({
|
||||
...line,
|
||||
text: line.text.replace(this.locationReference, ''),
|
||||
}));
|
||||
}
|
||||
|
||||
process(output: ResultLine[], _: ParseFiltersAndOutputOptions, optPipelineOptions: OptPipelineBackendOptions) {
|
||||
// Crop out any junk before the pass dumps
|
||||
const ir = output.slice(output.findIndex(line => this.irDumpHeader.test(line.text)));
|
||||
|
||||
const preprocessed_lines = this.applyIrFilters(ir);
|
||||
return this.breakdownOutput(preprocessed_lines, optPipelineOptions);
|
||||
}
|
||||
}
|
||||
BIN
public/logos/triton.png
Normal file
BIN
public/logos/triton.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.3 KiB |
@@ -24,6 +24,7 @@
|
||||
|
||||
import {describe, expect, it} from 'vitest';
|
||||
import {AsmParser} from '../lib/parsers/asm-parser.js';
|
||||
import {MlirAsmParser} from '../lib/parsers/asm-parser-mlir.js';
|
||||
import {PTXAsmParser} from '../lib/parsers/asm-parser-ptx.js';
|
||||
|
||||
describe('AsmParser tests', () => {
|
||||
@@ -196,3 +197,64 @@ vprintf,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('MlirAsmParser tests', () => {
|
||||
const parser = new MlirAsmParser();
|
||||
|
||||
describe('Location handling', () => {
|
||||
it('should process MLIR with location references', () => {
|
||||
const input = `
|
||||
#loc = loc("<source>":7:0)
|
||||
module {
|
||||
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("<source>":7:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("<source>":7:0)) attributes {noinline = false} {
|
||||
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
|
||||
%0 = tt.get_program_id x : i32 loc(#loc2)
|
||||
} loc(#loc)
|
||||
} loc(#loc)
|
||||
#loc1 = loc(unknown)
|
||||
#loc2 = loc("<source>":14:24)`;
|
||||
|
||||
const result = parser.processAsm(input, {});
|
||||
|
||||
// Verify location definitions are removed
|
||||
expect(result.asm.find(line => line.text.includes('#loc = loc'))).toBeUndefined();
|
||||
expect(result.asm.find(line => line.text.includes('#loc1 = loc'))).toBeUndefined();
|
||||
expect(result.asm.find(line => line.text.includes('#loc2 = loc'))).toBeUndefined();
|
||||
|
||||
// Verify location references are removed from displayed text
|
||||
expect(result.asm.find(line => line.text.includes('loc(#loc)'))).toBeUndefined();
|
||||
expect(result.asm.find(line => line.text.includes('loc(#loc1)'))).toBeUndefined();
|
||||
expect(result.asm.find(line => line.text.includes('loc(#loc2)'))).toBeUndefined();
|
||||
|
||||
// Verify inline locations are removed
|
||||
expect(result.asm.find(line => line.text.includes('loc("<source>"'))).toBeUndefined();
|
||||
|
||||
// Verify source information is correctly associated
|
||||
const programIdLine = result.asm.find(line => line.text.includes('tt.get_program_id'));
|
||||
expect(programIdLine).toBeDefined();
|
||||
expect(programIdLine?.source).toBeDefined();
|
||||
expect(programIdLine?.source?.file).toBe('<source>');
|
||||
expect(programIdLine?.source?.line).toBe(14);
|
||||
expect(programIdLine?.source?.column).toBe(24);
|
||||
|
||||
// Verify unknown locations are not associated with source information
|
||||
const constantLine = result.asm.find(line => line.text.includes('arith.constant'));
|
||||
expect(constantLine).toBeDefined();
|
||||
expect(constantLine?.source).toBeNull();
|
||||
|
||||
// Verify the structure is preserved
|
||||
const moduleStartLine = result.asm.find(line => line.text.includes('module {'));
|
||||
expect(moduleStartLine).toBeDefined();
|
||||
|
||||
const funcLine = result.asm.find(line => line.text.includes('tt.func public @add_kernel('));
|
||||
expect(funcLine).toBeDefined();
|
||||
expect(funcLine?.text?.includes('%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
|
||||
expect(funcLine?.text?.includes('%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
|
||||
expect(funcLine?.text?.includes('%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, ')).toBe(true);
|
||||
expect(funcLine?.text?.includes('%arg3: i32 {tt.divisibility = 16 : i32})')).toBe(true);
|
||||
|
||||
const constLine = result.asm.find(line => line.text.includes('arith.constant 1024'));
|
||||
expect(constLine).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
213
test/mlir-pass-dump-parser-tests.ts
Normal file
213
test/mlir-pass-dump-parser-tests.ts
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) 2025, Compiler Explorer Authors
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
// POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import {beforeAll, describe, expect, it} from 'vitest';
|
||||
|
||||
import {MlirPassDumpParser} from '../lib/parsers/mlir-pass-dump-parser.js';
|
||||
import * as properties from '../lib/properties.js';
|
||||
|
||||
function deepCopy<T>(obj: T): T {
|
||||
return JSON.parse(JSON.stringify(obj));
|
||||
}
|
||||
|
||||
describe('mlir-pass-dump-parser', () => {
|
||||
let mlirPassDumpParser: MlirPassDumpParser;
|
||||
|
||||
beforeAll(() => {
|
||||
const fakeProps = new properties.CompilerProps({mlir: {id: 'mlir'}}, properties.fakeProps({}));
|
||||
const compilerProps = (fakeProps.get as any).bind(fakeProps, 'mlir');
|
||||
mlirPassDumpParser = new MlirPassDumpParser(compilerProps);
|
||||
});
|
||||
|
||||
const rawMlirDump = [
|
||||
{
|
||||
text: "// -----// IR Dump Before Inliner (inline) ('builtin.module' operation) //----- //",
|
||||
},
|
||||
{text: 'module {'},
|
||||
{
|
||||
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
|
||||
},
|
||||
{text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'},
|
||||
{text: ' tt.return loc(#loc2)'},
|
||||
{text: ' } loc(#loc)'},
|
||||
{text: '} loc(#loc)'},
|
||||
{text: '#loc = loc("/app/example.py":7:0)'},
|
||||
{text: '#loc1 = loc("/app/example.py":8:24)'},
|
||||
{text: '#loc2 = loc("/app/example.py":8:4)'},
|
||||
{text: ''},
|
||||
{text: ''},
|
||||
{
|
||||
text: "// -----// IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel) //----- //",
|
||||
},
|
||||
{text: 'module {'},
|
||||
{
|
||||
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
|
||||
},
|
||||
{text: ' %0 = tt.get_program_id x : i32 loc(#loc1)'},
|
||||
{text: ' tt.return loc(#loc2)'},
|
||||
{text: ' } loc(#loc)'},
|
||||
{text: '} loc(#loc)'},
|
||||
{text: '#loc = loc("/app/example.py":7:0)'},
|
||||
{text: '#loc1 = loc("/app/example.py":8:24)'},
|
||||
{text: '#loc2 = loc("/app/example.py":8:4)'},
|
||||
{text: ''},
|
||||
{text: ''},
|
||||
{
|
||||
text: "// -----// IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation) //----- //",
|
||||
},
|
||||
{text: 'module {'},
|
||||
{
|
||||
text: ' tt.func public @add_kernel() attributes {noinline = false} {',
|
||||
},
|
||||
{text: ' tt.return loc(#loc1)'},
|
||||
{text: ' } loc(#loc)'},
|
||||
{text: '} loc(#loc)'},
|
||||
{text: '#loc = loc("/app/example.py":7:0)'},
|
||||
{text: '#loc1 = loc("/app/example.py":8:4)'},
|
||||
{text: ''},
|
||||
{text: ''},
|
||||
];
|
||||
|
||||
it('should break down output into pass dumps', () => {
|
||||
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
|
||||
|
||||
expect(passDumps.length).toBe(3);
|
||||
expect(passDumps[0].header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)");
|
||||
expect(passDumps[1].header).toBe(
|
||||
"IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)",
|
||||
);
|
||||
expect(passDumps[2].header).toBe(
|
||||
"IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)",
|
||||
);
|
||||
|
||||
// Check that the first pass dump has the correct lines
|
||||
expect(passDumps[0].lines.length).toBe(10);
|
||||
expect(passDumps[0].lines[0].text).toBe('module {');
|
||||
expect(passDumps[0].lines[1].text).toBe(' tt.func public @add_kernel() attributes {noinline = false} {');
|
||||
});
|
||||
|
||||
it('should break down pass dumps into functions', () => {
|
||||
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
|
||||
const splitPassDump = mlirPassDumpParser.breakdownPassDumpsIntoFunctions(passDumps[0]);
|
||||
|
||||
expect(splitPassDump.header).toBe("IR Dump Before Inliner (inline) ('builtin.module' operation)");
|
||||
expect(Object.keys(splitPassDump.functions)).toContain('add_kernel');
|
||||
expect(splitPassDump.functions['add_kernel'].length).toBe(10);
|
||||
});
|
||||
|
||||
it('should apply IR filters to remove location information', () => {
|
||||
const filtered = mlirPassDumpParser.applyIrFilters(deepCopy(rawMlirDump.slice(0, 7)));
|
||||
|
||||
// Should filter out location references but keep the lines
|
||||
expect(filtered.length).toBe(7);
|
||||
expect(filtered[3].text).toBe(' %0 = tt.get_program_id x : i32');
|
||||
expect(filtered[4].text).toBe(' tt.return');
|
||||
expect(filtered[5].text).toBe(' }');
|
||||
expect(filtered[6].text).toBe('}');
|
||||
});
|
||||
|
||||
it('should break down output into pass dumps by function', () => {
|
||||
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
|
||||
const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump));
|
||||
const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps);
|
||||
|
||||
expect(Object.keys(passDumpsByFunction)).toContain('add_kernel');
|
||||
expect(passDumpsByFunction['add_kernel'].length).toBe(3);
|
||||
|
||||
// Check that the function has all three passes
|
||||
const headers = passDumpsByFunction['add_kernel'].map(dump => dump.header);
|
||||
expect(headers).toContain("IR Dump Before Inliner (inline) ('builtin.module' operation)");
|
||||
expect(headers).toContain("IR Dump Before Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
|
||||
expect(headers).toContain(
|
||||
"IR Dump Before TritonRewriteTensorPointer (triton-rewrite-tensor-pointer) ('builtin.module' operation)",
|
||||
);
|
||||
});
|
||||
|
||||
it('should detect IR changes between passes', () => {
|
||||
// Create two different IR dumps to test change detection
|
||||
const before = [
|
||||
{text: 'module {'},
|
||||
{text: ' tt.func public @add_kernel() {'},
|
||||
{text: ' %0 = tt.get_program_id x : i32'},
|
||||
{text: ' tt.return'},
|
||||
{text: ' }'},
|
||||
{text: '}'},
|
||||
];
|
||||
|
||||
const afterNoChange = deepCopy(before);
|
||||
expect(mlirPassDumpParser.isIrChanged(before, afterNoChange)).toBe(false);
|
||||
|
||||
const afterWithChange = [
|
||||
{text: 'module {'},
|
||||
{text: ' tt.func public @add_kernel() {'},
|
||||
{text: ' tt.return'}, // Line removed
|
||||
{text: ' }'},
|
||||
{text: '}'},
|
||||
];
|
||||
expect(mlirPassDumpParser.isIrChanged(before, afterWithChange)).toBe(true);
|
||||
});
|
||||
|
||||
it('should match pass dumps and detect changes', () => {
|
||||
const passDumps = mlirPassDumpParser.breakdownOutputIntoPassDumps(deepCopy(rawMlirDump));
|
||||
const splitPassDumps = passDumps.map(dump => mlirPassDumpParser.breakdownPassDumpsIntoFunctions(dump));
|
||||
const passDumpsByFunction = mlirPassDumpParser.breakdownIntoPassDumpsByFunction(splitPassDumps);
|
||||
const matchedPasses = mlirPassDumpParser.matchPassDumps(passDumpsByFunction);
|
||||
|
||||
expect(Object.keys(matchedPasses)).toContain('add_kernel');
|
||||
|
||||
const addKernelPasses = matchedPasses['add_kernel'];
|
||||
expect(addKernelPasses.length).toBe(2); // We should have 2 passes (comparing the 3 dumps)
|
||||
|
||||
// Check the first pass (Inliner to Canonicalizer)
|
||||
expect(addKernelPasses[0].name).toBe("Inliner (inline) ('builtin.module' operation)");
|
||||
expect(addKernelPasses[0].irChanged).toBe(false); // No changes between these two dumps
|
||||
|
||||
// Check the second pass (Canonicalizer to TritonRewriteTensorPointer)
|
||||
expect(addKernelPasses[1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
|
||||
expect(addKernelPasses[1].irChanged).toBe(true); // There are changes between these two dumps
|
||||
});
|
||||
|
||||
it('should process the complete output correctly', () => {
|
||||
const optPipelineOptions = {
|
||||
fullModule: false,
|
||||
filterDebugInfo: false,
|
||||
filterIRMetadata: false,
|
||||
noDiscardValueNames: true,
|
||||
demangle: true,
|
||||
libraryFunctions: false,
|
||||
};
|
||||
const result = mlirPassDumpParser.process(deepCopy(rawMlirDump), {}, optPipelineOptions);
|
||||
|
||||
expect(Object.keys(result)).toContain('add_kernel');
|
||||
expect(result['add_kernel'].length).toBe(2);
|
||||
|
||||
// Verify the passes are correctly identified
|
||||
expect(result['add_kernel'][0].name).toBe("Inliner (inline) ('builtin.module' operation)");
|
||||
expect(result['add_kernel'][1].name).toBe("Canonicalizer (canonicalize) ('tt.func' operation: @add_kernel)");
|
||||
|
||||
// Verify the IR changes are correctly detected
|
||||
expect(result['add_kernel'][0].irChanged).toBe(false);
|
||||
expect(result['add_kernel'][1].irChanged).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -101,6 +101,7 @@ export type LanguageKey =
|
||||
| 'swift'
|
||||
| 'tablegen'
|
||||
| 'toit'
|
||||
| 'triton'
|
||||
| 'typescript'
|
||||
| 'v'
|
||||
| 'vala'
|
||||
|
||||
Reference in New Issue
Block a user