Skip to content

Kerncap

Kerncap profiles a running application, intercepts a target kernel dispatch, captures its complete runtime state (full device memory snapshot, kernarg buffer, HSACO), and generates a standalone reproducer that can replay the kernel in isolation using VA-faithful HSA dispatch.

How it works

1. Profile rocprofv3 --kernel-trace --stats → rank kernels by duration
2. Capture HIP: LD_PRELOAD=libkerncap.so → intercept target dispatch,
snapshot all tracked device memory + kernarg buffer + HSACO
Triton: Python-level hook on JITFunction.run → capture all tensor,
scalar, and constexpr args; pin autotuner config
3. Find source HIP: __global__ grep + #include tracing
Triton: @triton.jit AST match + import tracing (incl. relative imports)
4. Generate Jinja2 templates → standalone .hip+Makefile or .py reproducer
5. Validate Build, run reproducer, np.allclose against captured reference

Installation

Builds libkerncap.so from source against the host ROCm 7.0+ (requires hipcc, cmake, HSA headers, rocprofiler-sdk — all present in standard ROCm images).

Terminal window
# From local source
pip install .
# Editable install for development
pip install -e .[dev]

Usage

Each operation is available as both a Python API and a CLI command.

Profile

Rank kernels by total GPU execution time.

from kerncap import Kerncap
kc = Kerncap()
profile = kc.profile(["./my_app", "--args"])
for kernel in profile[:5]:
print(f"{kernel.name}: {kernel.total_duration_ns / 1e6:.1f} ms ({kernel.percentage:.1f}%)")

Extract

Capture a kernel’s full runtime state and generate a standalone reproducer.

# HIP kernel with source (enables recompile workflow)
result = kc.extract(
kernel_name="mul_mat_q",
cmd=["./llama-bench", "-m", "model.gguf", "-p", "512"],
source_dir="./ggml/src",
output="./isolated/mul_mat_q",
defines=["GGML_USE_HIP", "GGML_CUDA_FA_ALL_QUANTS"],
)
print(f"Output: {result.output_dir} has_source: {result.has_source}")
# Triton kernel — language auto-detected from source
result = kc.extract(
kernel_name="flash_attn_fwd",
cmd=["python", "train.py", "--batch-size", "64"],
source_dir="./flash_attn",
output="./isolated/flash_attn_fwd",
)

Language detection: Kerncap auto-detects whether a kernel is HIP or Triton from --source-dir contents. To override, pass --language hip or --language triton on the CLI (or language="triton" in the Python API).

Replay

Replay a captured kernel in isolation.

# Replay with captured HSACO
baseline = kc.replay("./isolated/mul_mat_q")
print(f"Baseline: {baseline.timing_us:.1f} us")
# Replay with a variant HSACO and compare
variant = kc.replay("./isolated/mul_mat_q", hsaco="./isolated/mul_mat_q/optimized.hsaco")
print(f"Variant: {variant.timing_us:.1f} us")
print(f"Speedup: {baseline.timing_us / variant.timing_us:.2f}x")

HIP launch mode: If replay conflicts with rocprofv3 (e.g. when profiling the reproducer itself), pass --hip-launch to use the HIP runtime launch path instead of the default HSA dispatch.

Validate

Check correctness of a reproducer or variant HSACO.

# Smoke test — confirm baseline replays without error
result = kc.validate("./isolated/mul_mat_q")
print("Passed:", result.passed)
# Correctness check — compare recompiled variant against captured baseline
result = kc.validate("./isolated/mul_mat_q", hsaco="./isolated/mul_mat_q/optimized.hsaco")
print("Passed:", result.passed)
# Triton — compare against captured reference with tolerance
result = kc.validate("./isolated/flash_attn_fwd", tolerance=1e-3, rtol=1e-2)
print("Passed:", result.passed)

HIP vs Triton validation: For HIP kernels, baseline validate is a smoke test only. Pass hsaco to compare a recompiled variant byte-for-byte against the captured baseline. For Triton reproducers, validate compares outputs against captured reference data using np.allclose.

Extract methodology

The extract stage takes a kernel name and a runnable command and produces a fully self-contained reproducer project. Under the hood it runs three sub-stages — capture, find source, generate — each with language-specific paths for HIP and Triton.

Capture

Snapshots the full runtime state of a single kernel dispatch for later replay.

HIP kernels are captured at the HSA level. libkerncap.so (loaded via LD_PRELOAD with rocprofiler-sdk registration) hooks hsa_queue_create to install a packet intercept callback. When the target dispatch arrives, kerncap interposes a completion signal, waits for the kernel to finish, then walks the kernarg buffer. All device memory allocations are tracked via hsa_amd_memory_pool_allocate and hsa_amd_vmem_* hooks. At capture time, a full device memory snapshot is taken — every tracked allocation is D2H copied. The replay binary restores all memory at the original virtual addresses using HSA VMEM APIs, then dispatches the kernel with the captured HSACO. No DWARF metadata or argument parsing needed.

Triton kernels are captured at the Python level via monkey-patching JITFunction.run and Autotuner.run. Inputs (tensors, scalars, constexprs) are serialized before launch and reference outputs saved after. For autotuned kernels the winning config is recorded so the reproducer can pin it exactly.

Find source

Locates the kernel’s source so the reproducer can compile (HIP) or import (Triton) it.

HIP: Searches the source tree for __global__ declarations matching the demangled kernel name, then traces local #include "..." directives recursively (depth 5) to collect all required headers. For projects built with CMake, kerncap uses compile_commands.json to find the translation unit and extract the exact compiler flags needed for the make recompile target.

Note: compile_commands.json is optional — the capture/replay/validate workflow works without it. However, the make recompile target (for editing and rebuilding kernels) requires it. Generate it by adding -DCMAKE_EXPORT_COMPILE_COMMANDS=ON when configuring your CMake project. If missing, you’ll see: “No compile command found (compile_commands.json missing or has no entry for this file). The ‘make recompile’ target will not be available.” — basic validation still works, but you won’t be able to test kernel modifications.

Triton: Parses Python files under --source-dir with the ast module, matching @triton.jit/@triton.autotune decorators. ImportFrom nodes (including relative imports) are traced to resolve the full dependency set.

Generate

The captured data and located source files are assembled into a standalone project using Jinja2 templates.

HIP kernels produce a VA-faithful replay project using kerncap-replay. The captured HSACO, kernarg buffer, and full device memory snapshot are stored in capture/. make run replays the kernel at the original virtual addresses using HSA VMEM APIs — no kernel source compilation needed.

When --source-dir is provided, kerncap additionally finds the .cu translation unit (via compile_commands.json or reverse-include search) and produces:

  • kernel_variant.cpp — a copy of the main kernel source file for editing
  • deps/ — copies of all #include dependency headers (traced up to 5 levels deep)
  • vfs.yaml — a Clang Virtual File System overlay that maps all local copies over the originals during recompilation

This enables the optimization workflow below.

OutputAlwaysWith --source-dir
Captured state (capture/)YesYes
Editable source (kernel_variant.cpp)Yes
Dependency headers (deps/)Yes (when deps exist)
VFS overlay (vfs.yaml)Yes
MakefileYesYes

Triton kernels produce a reproducer.py that imports the kernel from the copied source tree, loads tensor arguments from binary dumps, and calls the kernel. For autotuned kernels, the reproducer calls kernel.fn directly with pinned config kwargs, bypassing re-tuning entirely (see Triton autotuner reproducibility).

Optimization workflow

When source_dir is provided, extract produces a self-contained project for a tight edit-recompile-validate loop:

kernel_variant.cpp Editable copy of the main kernel source file
deps/ Copies of all #include dependency headers (up to 5 levels)
vfs.yaml Clang VFS overlay — maps local copies over originals at compile time
capture/ VA-faithful memory snapshot, dispatch metadata, baseline HSACO
Makefile make run | make recompile | make run-variant | make validate-variant

Prerequisites for make recompile: The recompilation workflow requires compile_commands.json from your project’s build directory. If your project was built with CMake, regenerate with -DCMAKE_EXPORT_COMPILE_COMMANDS=ON to enable this. Without it, make run and kerncap validate still work (they use the captured HSACO), but make recompile won’t be available.

import subprocess, os
from kerncap import Kerncap
kc = Kerncap()
# 1. Extract (once)
result = kc.extract("mul_mat_q", cmd=[...], source_dir="./ggml/src", output="./isolated/mul_mat_q")
reproducer_dir = result.output_dir
# 2. Edit kernel_variant.cpp or files in deps/ (do not change the kernel signature)
# 3. Recompile — single kernel, no application rebuild
subprocess.run(["make", "recompile"], cwd=reproducer_dir, check=True)
# 4. Compare baseline vs variant
baseline = kc.replay(reproducer_dir)
variant = kc.replay(reproducer_dir, hsaco=os.path.join(reproducer_dir, "optimized.hsaco"))
print(f"Baseline: {baseline.timing_us:.1f} us Variant: {variant.timing_us:.1f} us")
print(f"Speedup: {baseline.timing_us / variant.timing_us:.2f}x")
# 5. Validate correctness
result = kc.validate(reproducer_dir, hsaco=os.path.join(reproducer_dir, "optimized.hsaco"))
print("Passed:", result.passed)

AI-assisted optimization

The Python API is designed for LLM-driven workflows. A Cursor agent (or any LLM with code execution) can drive the full pipeline — profile, extract, recompile, benchmark, validate — entirely through the Kerncap class without shell scripting. The key inputs to provide are:

  • The application command (cmd)
  • The source directory (source_dir) and any preprocessor defines (defines)
  • The reproducer directory for the edit-recompile-validate loop

capture/dispatch.json is particularly useful context for an LLM: it contains the kernel’s launch configuration (grid/block dims, kernarg layout, GPU architecture) alongside the captured HSACO, giving a complete picture of what the kernel does and how it is launched before any source is read.

Technical details

Embedded device pointers

Kerncap uses VA-faithful replay: all device memory is captured in a full snapshot and restored at the original virtual addresses during replay. Embedded device pointers (e.g. T** in batched BLAS, structs with pointer members) work automatically — no pointer patching or relocation tables needed.

Triton autotuner reproducibility

Triton’s @triton.autotune selects a config by benchmarking (e.g. BLOCK_M=128, num_warps=4). Different configs change FP accumulation order, which can cause large numerical differences in FP16. Kerncap captures the winning config and pins it in the reproducer, bypassing re-tuning entirely.

If validation fails with tight tolerances, use kerncap validate --tolerance <atol> to relax the threshold.

NaN in validation output: Common causes are uninitialized device memory, FP16 overflow, or wrong dtype inference. The validator reports NaN counts per argument and sets max_error to nan.

Validation targets

  • Triton: Flash Attention forward kernel (ROCm/flash-attention, Triton backend) in rocm/pytorch container
  • HIP: Composable Kernel GEMM XDL FP16 (ROCm/composable_kernel) in rocm/composable_kernel:ck_pytorch container
  • HIP (embedded pointers): Batched vector scale kernel in local ROCm environment, testing T** (double-pointer) arguments via VA-faithful replay

llama.cpp/ggml kernels (template-qualified names like mul_mat_q<(ggml_type)7, 32, true>) are also supported via the -D flag for preprocessor defines.

Project structure

src/kerncap.{hip,hpp} HSA tool loaded via LD_PRELOAD (rocprofiler-sdk registration)
src/replay.cpp VA-faithful HSA kernel replay binary (kerncap-replay)
kerncap/ Python package (CLI, profiler, capturer, source finder,
reproducer generator, validator)
kerncap/templates/ Jinja2 templates for HIP and Triton reproducers
vendor/ Vendored nlohmann/json headers
tests/ Unit + integration tests