Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/tutorials/cli_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ Use exhaustive search on a tiling strategy limited to tile4d + only vectorized t
450/450 [22:30<00:00, 3.00s/it]
real 1352.37

Resume interrupted exploration while keeping reproducibility metadata:

# Start a run
loop-explore --search random --trials 200 --output results.random.csv

# Resume the same output (skips already recorded samples)
loop-explore --search random --trials 200 --output results.random.csv --resume

# Append regardless of duplicates
loop-explore --search random --trials 200 --output results.random.csv --append

Each run also writes `results.random.csv.meta.json` with the command arguments,
Python/runtime information, and git commit hash to simplify result provenance.

Test a single tiling:

# Dump and execute MLIR tiling
Expand Down
149 changes: 142 additions & 7 deletions src/xtc/cli/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import logging
import itertools
import csv
import json
import random
from datetime import datetime, timezone
import numpy as np
import numpy.typing
from tqdm import tqdm
Expand Down Expand Up @@ -574,17 +576,113 @@ def peak_time(args: NS) -> float:
return time


def _args_to_metadata(args: NS) -> dict[str, Any]:
metadata_args: dict[str, Any] = {}
for key, value in vars(args).items():
if key == "eval_parameters":
continue
if isinstance(value, Path):
metadata_args[key] = str(value)
elif isinstance(value, (str, int, float, bool, list, dict)) or value is None:
metadata_args[key] = value
else:
metadata_args[key] = str(value)
return metadata_args


def _git_commit_hash() -> str | None:
    """Return the current git HEAD commit hash, or ``None`` when unavailable."""
    command = ["git", "rev-parse", "HEAD"]
    try:
        completed = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # git executable is not installed or not on PATH.
        return None
    except subprocess.CalledProcessError:
        # git failed, e.g. the current directory is not a repository.
        return None
    return completed.stdout.strip()


def write_run_manifest(args: NS, strategy: Strategy) -> None:
    """Dump a ``<output>.meta.json`` manifest describing this run.

    The manifest records the command arguments, interpreter and platform
    details, the strategy sample names and (when available) the git commit,
    so results can be traced back to the code/configuration that produced
    them.
    """
    out = Path(args.output)
    manifest = {
        "generatedAt": datetime.now(timezone.utc).isoformat(),
        "output": str(out),
        "resume": args.resume,
        "append": args.append,
        "sampleNames": strategy.sample_names,
        "gitCommit": _git_commit_hash(),
        "python": {
            "version": sys.version,
            "executable": sys.executable,
        },
        "platform": {
            "platform": sys.platform,
            "cwd": str(Path.cwd()),
        },
        "args": _args_to_metadata(args),
    }
    manifest_path = Path(f"{args.output}.meta.json")
    with open(manifest_path, "w", encoding="utf-8") as manifest_file:
        json.dump(manifest, manifest_file, indent=2, sort_keys=True)
        manifest_file.write("\n")


class CSVCallback:
    """Search-result callback that records evaluated samples into a CSV file.

    Supports three output modes:
      * default -- truncate and rewrite the output file,
      * append  -- append rows without any deduplication,
      * resume  -- append rows, skipping (backend, sample) pairs already
                   present in the existing file.
    """

    def __init__(
        self,
        fname: str,
        peak_time: float,
        sample_names: list[str],
        *,
        resume: bool = False,
        append: bool = False,
    ) -> None:
        """Open *fname* in the mode implied by *resume*/*append*.

        NOTE(review): the scraped diff retained superseded old-version
        statements (a second ``open(fname, "w")``, duplicate header write
        and list resets) that would truncate the file prematurely and leak
        a file handle; this reconstruction keeps a single coherent flow.
        """
        self._fname = fname
        self._peak_time = peak_time
        self._sample_names = sample_names
        self._header = sample_names + ["X", "time", "peak", "backend"]
        self._results: list[Sequence] = []
        self._rows: list[Sequence] = []
        self._seen_keys: set[tuple[str, tuple[int, ...]]] = set()
        self._resume = resume
        self._append = append

        out_path = Path(fname)
        # Inspect the file BEFORE opening: "w" mode would truncate it.
        has_existing_file = out_path.exists() and out_path.stat().st_size > 0
        if resume:
            # Pre-load recorded (backend, sample) keys so duplicates are skipped.
            self._load_existing_rows()
            mode = "a"
        elif append:
            mode = "a"
        else:
            mode = "w"

        self._outf = open(fname, mode, newline="")
        self._writer = csv.writer(self._outf, delimiter=",")
        # Only emit the header for a fresh/truncated file; appending to a
        # non-empty file would duplicate it.
        should_write_header = (not has_existing_file) or (mode == "w")
        if should_write_header:
            self._write_header()

def _load_existing_rows(self) -> None:
    """Populate ``self._seen_keys`` from an existing output CSV.

    Rows with a missing backend or non-integer sample values are
    silently ignored rather than aborting the resume.
    """
    path = Path(self._fname)
    if not path.exists() or path.stat().st_size == 0:
        return
    with open(path, newline="") as handle:
        for record in csv.DictReader(handle, delimiter=","):
            backend_name = record.get("backend")
            if backend_name is None:
                continue
            try:
                key = tuple(int(record[name]) for name in self._sample_names)
            except (TypeError, ValueError, KeyError):
                # Malformed or truncated row: skip it.
                continue
            self._seen_keys.add((backend_name, key))

def _sample_key(self, x: Sample, backend: str) -> tuple[str, tuple[int, ...]]:
    """Build the deduplication key for sample *x* under *backend*."""
    normalized = tuple(int(component) for component in x)
    return backend, normalized

def _write_header(self) -> None:
    # Emit the CSV header row: sample names followed by X/time/peak/backend.
    self._writer.writerow(self._header)
Expand All @@ -594,19 +692,28 @@ def _write_row(self, row: Sequence) -> None:
self._rows.append(row)
self._writer.writerow(row)
self._outf.flush()
try:
os.fsync(self._outf.fileno())
except OSError:
logger.debug("Unable to fsync output file %s", self._fname)

def _write_result(self, result: Sequence) -> None:
    """Record one search result, writing a CSV row for successful runs.

    *result* unpacks as ``(x, error, time, backend)``. Errored evaluations
    are kept in ``self._results`` but never written to the CSV. In resume
    mode, samples already present in the output file are skipped.

    Fix: use lazy %-style logging args instead of f-strings so the message
    is only formatted when DEBUG is enabled (rendered text is unchanged).
    """
    self._results.append(result)
    x, error, time, backend = result
    if error != 0:
        # Failed evaluation: keep it in memory only.
        logger.debug("Skip recording error for: %s: %s", backend, x)
        return
    key = self._sample_key(x, backend)
    if self._resume and key in self._seen_keys:
        logger.debug("Skip already recorded sample for resume mode: %s", key)
        return
    peak = self._peak_time / time
    # Serialize the sample with ';' so it survives the ',' CSV delimiter.
    s = str(x).replace(",", ";")
    row = [s, time, peak, backend]
    row = x + row
    logger.debug("Record row: %s", row)
    self._write_row(row)
    self._seen_keys.add(key)

def __call__(self, result: Sequence) -> None:
    # Callback entry point invoked by the search driver: delegate to the
    # result-recording logic.
    self._write_result(result)
Expand All @@ -627,7 +734,13 @@ def search_some(strategy: Strategy, graph: Graph, args: NS):
)
ptime = peak_time(args)
sample_names = strategy.sample_names
result_callback = CSVCallback(args.output, ptime, sample_names)
result_callback = CSVCallback(
args.output,
ptime,
sample_names,
resume=args.resume,
append=args.append,
)
callbacks = {
"result": result_callback,
"search": search_callback,
Expand Down Expand Up @@ -665,6 +778,7 @@ def optimize(args: NS):
op_args = (*dims, dtype)
graph = OPERATORS[args.operator]["operation"](*op_args, name=args.func_name)
strategy = get_strategy(graph, args)
write_run_manifest(args, strategy)
if args.test or args.opt_level in [0, 1, 2, 3]:
schedule = args.test
if not schedule:
Expand All @@ -679,7 +793,13 @@ def optimize(args: NS):
)
ptime = peak_time(args)
sample_names = strategy.sample_names
result_callback = CSVCallback(args.output, ptime, sample_names)
result_callback = CSVCallback(
args.output,
ptime,
sample_names,
resume=args.resume,
append=args.append,
)
callbacks = {
"result": result_callback,
"search": search_callback,
Expand Down Expand Up @@ -959,6 +1079,18 @@ def main():
parser.add_argument(
"--output", type=str, default="results.csv", help="output csv file for search"
)
parser.add_argument(
"--resume",
action=argparse.BooleanOptionalAction,
default=False,
help="resume from an existing output file and skip already recorded samples",
)
parser.add_argument(
"--append",
action=argparse.BooleanOptionalAction,
default=False,
help="append new results to output file without deduplication",
)
parser.add_argument(
"--eval", type=str, choices=["eval"], default="eval", help="evaluation method"
)
Expand Down Expand Up @@ -1040,6 +1172,9 @@ def main():
)
args = parser.parse_args()

if args.resume and args.append:
parser.error("--resume and --append cannot be used together")

logging.basicConfig()
logger.setLevel(logging.INFO)
if args.debug:
Expand Down
120 changes: 120 additions & 0 deletions tests/pytest/unit/cli/test_explore_csv_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import csv
from pathlib import Path

from xtc.cli.explore import CSVCallback


def _read_rows(path: Path) -> list[list[str]]:
with path.open(newline="") as infile:
return list(csv.reader(infile, delimiter=","))


def test_csv_callback_resume_dedup_skips_existing_and_keeps_new(tmp_path: Path):
    """Resume mode skips recorded (backend, sample) pairs but keeps new ones."""
    result_file = tmp_path / "results.csv"
    names = ["M", "N"]

    # Seed the output with one mlir row for sample [8, 16].
    with result_file.open("w", newline="") as handle:
        seed = csv.writer(handle, delimiter=",")
        seed.writerow(names + ["X", "time", "peak", "backend"])
        seed.writerow([8, 16, "[8; 16]", 0.2, 5.0, "mlir"])

    callback = CSVCallback(
        str(result_file),
        peak_time=1.0,
        sample_names=names,
        resume=True,
    )
    callback(([8, 16], 0, 0.2, "mlir"))  # duplicate backend+sample -> skipped
    callback(([8, 16], 0, 0.25, "tvm"))  # same sample, new backend -> kept
    callback(([8, 32], 0, 0.3, "mlir"))  # new sample -> kept
    del callback

    rows = _read_rows(result_file)

    # Header + seeded row + the two newly recorded rows, in order.
    assert len(rows) == 4
    assert rows[0] == ["M", "N", "X", "time", "peak", "backend"]
    assert rows[2][:2] == ["8", "16"]
    assert rows[2][-1] == "tvm"
    assert rows[3][:2] == ["8", "32"]
    assert rows[3][-1] == "mlir"


def test_csv_callback_default_mode_overwrites_existing_file(tmp_path: Path):
    """Without resume/append the callback truncates any existing output."""
    result_file = tmp_path / "results.csv"
    names = ["M", "N"]

    with result_file.open("w", newline="") as handle:
        seed = csv.writer(handle, delimiter=",")
        seed.writerow(names + ["X", "time", "peak", "backend"])
        seed.writerow([9, 9, "[9; 9]", 0.9, 1.1, "mlir"])

    callback = CSVCallback(str(result_file), peak_time=2.0, sample_names=names)
    # Default mode rewrites the file, discarding the seeded row.
    callback(([2, 3], 0, 0.5, "mlir"))
    del callback

    rows = _read_rows(result_file)
    assert rows[0] == ["M", "N", "X", "time", "peak", "backend"]
    assert len(rows) == 2
    assert rows[1][:2] == ["2", "3"]
    assert rows[1][-1] == "mlir"


def test_csv_callback_append_mode_allows_duplicates(tmp_path: Path):
    """Append mode records rows even when the sample/backend already exists."""
    result_file = tmp_path / "results.csv"
    names = ["M", "N"]

    with result_file.open("w", newline="") as handle:
        seed = csv.writer(handle, delimiter=",")
        seed.writerow(names + ["X", "time", "peak", "backend"])
        seed.writerow([4, 4, "[4; 4]", 0.1, 10.0, "mlir"])

    callback = CSVCallback(
        str(result_file),
        peak_time=1.0,
        sample_names=names,
        append=True,
    )
    # Same sample/backend is allowed: append mode never deduplicates.
    callback(([4, 4], 0, 0.12, "mlir"))
    del callback

    rows = _read_rows(result_file)

    # Header + seeded row + appended duplicate.
    assert len(rows) == 3
    for row in rows[1:]:
        assert row[:2] == ["4", "4"]
        assert row[-1] == "mlir"


def test_csv_callback_append_writes_header_for_new_file(tmp_path: Path):
    """Append mode still writes the header when the output does not exist yet."""
    result_file = tmp_path / "results.csv"

    callback = CSVCallback(
        str(result_file),
        peak_time=2.0,
        sample_names=["M", "N"],
        append=True,
    )
    callback(([2, 3], 0, 0.5, "mlir"))
    del callback

    rows = _read_rows(result_file)
    assert rows[0] == ["M", "N", "X", "time", "peak", "backend"]
    assert rows[1][:2] == ["2", "3"]