diff --git a/harness/cifar10/__pycache__/cifar10.cpython-312.pyc b/harness/cifar10/__pycache__/cifar10.cpython-312.pyc deleted file mode 100644 index 8f9b69d..0000000 Binary files a/harness/cifar10/__pycache__/cifar10.cpython-312.pyc and /dev/null differ diff --git a/harness/cifar10/__pycache__/resnet20.cpython-312.pyc b/harness/cifar10/__pycache__/resnet20.cpython-312.pyc deleted file mode 100644 index daf88e7..0000000 Binary files a/harness/cifar10/__pycache__/resnet20.cpython-312.pyc and /dev/null differ diff --git a/harness/cifar10/__pycache__/test.cpython-312.pyc b/harness/cifar10/__pycache__/test.cpython-312.pyc deleted file mode 100644 index 02603b6..0000000 Binary files a/harness/cifar10/__pycache__/test.cpython-312.pyc and /dev/null differ diff --git a/harness/cifar10/__pycache__/train.cpython-312.pyc b/harness/cifar10/__pycache__/train.cpython-312.pyc deleted file mode 100644 index 7516026..0000000 Binary files a/harness/cifar10/__pycache__/train.cpython-312.pyc and /dev/null differ diff --git a/harness/cifar10/cifar10.py b/harness/cifar10/cifar10.py new file mode 100644 index 0000000..4d49f22 --- /dev/null +++ b/harness/cifar10/cifar10.py @@ -0,0 +1,259 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import numpy as np +import warnings +from torchvision import datasets, transforms +from torch.utils.data import DataLoader, random_split +from absl import app, flags +import os +import sys + +# Add the current directory to the Python path +sys.path.insert(0, os.path.dirname(__file__)) +from resnet20 import ResNet20 +import train +import test + +FLAGS = flags.FLAGS + +# 1. 
Configuration +BATCH_SIZE = 64 +LEARNING_RATE = 0.001 +WEIGHT_DECAY = 1e-4 +EPOCHS = 350 # Increased epochs for potentially better accuracy +MODEL_PATH = './harness/cifar10/cifar10_resnet20_model.pth' +RNG_SEED = 42 # for reproducibility +DATA_DIR='./harness/cifar10/data' + +# Define command line flags safely to allow importing this module from other apps +try: + flags.DEFINE_string('model_path', MODEL_PATH, 'Path to save/load the model') + flags.DEFINE_integer('batch_size', BATCH_SIZE, 'Batch size for training and evaluation') + flags.DEFINE_float('learning_rate', LEARNING_RATE, 'Learning rate for optimizer') + flags.DEFINE_float('weight_decay', WEIGHT_DECAY, 'Weight decay for optimizer') + flags.DEFINE_integer('epochs', EPOCHS, 'Number of training epochs') + flags.DEFINE_string('data_dir', './harness/cifar10/data', 'Directory to store/load CIFAR10 dataset') + flags.DEFINE_boolean('no_cuda', False, 'Disable CUDA even if available') + flags.DEFINE_integer('seed', RNG_SEED, 'Random seed for reproducibility') + + flags.DEFINE_boolean('export_test_data', False, 'Export test dataset to file and exit') + flags.DEFINE_string('test_data_output', 'cifar10_test.txt', 'Output file for exported test data') + flags.DEFINE_integer('num_samples', -1, 'Number of samples to export (-1 for all samples)') + + flags.DEFINE_boolean('predict', False, 'Run prediction on pixels file and exit') + flags.DEFINE_string('pixels_file', '', 'Path to file containing pixel data for prediction') + flags.DEFINE_string('predictions_file', 'predictions.txt', 'Output file for predictions') +except flags.DuplicateFlagError: + pass + +# Ensure reproducibility +torch.manual_seed(RNG_SEED) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(RNG_SEED) + +# 2. Data Loading and Preprocessing +def get_cifar10_transform(transform_type="validation"): + """ + Get the standard CIFAR10 transform for preprocessing. 
+ + Returns: + transforms.Compose: Transform pipeline for CIFAR10 data + """ + if transform_type == "train": + return transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + ]) + else: + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) + ]) + +def load_and_preprocess_data(batch_size=BATCH_SIZE, data_dir=DATA_DIR): + """ + Load and preprocess CIFAR10 dataset. + + Args: + batch_size (int): Batch size for data loaders + data_dir (str): Directory to store/load dataset + + Returns: + tuple: (train_loader, val_loader, test_loader) + """ + train_transform = get_cifar10_transform(transform_type="train") + test_transform = get_cifar10_transform() + + # Download CIFAR10 dataset. Suppress numpy VisibleDeprecationWarning + # originating from torchvision's CIFAR pickle loader on some numpy versions. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=getattr(np, 'VisibleDeprecationWarning', DeprecationWarning)) + full_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=train_transform) + test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, transform=test_transform) + + # Split training data into training and validation sets + train_size = int(0.8 * len(full_dataset)) + val_size = len(full_dataset) - train_size + train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + + return train_loader, val_loader, test_loader + + +# 4. 
Training Function: See train.py +def train_model(model_path, batch_size, learning_rate, weight_decay, epochs, train_loader, val_loader, data_dir, device): + """ + Build or load a ResNet20 model, train if necessary, and return the model. + Mirrors the `mnist.train_model` signature and behavior. + """ + channel_values = [16, 32, 64] + num_classes = 10 + + model = ResNet20(channel_values, num_classes).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + # If model exists, load it; otherwise train and save + if os.path.exists(model_path): + model.load_state_dict(torch.load(model_path, map_location=device)) + else: + # Train using the helper in harness/cifar10/train.py + train.train_model_function(model, train_loader, criterion, optimizer, num_epochs=epochs, device=device) + # Ensure directory exists when saving + model_dir = os.path.dirname(model_path) + if model_dir and not os.path.exists(model_dir): + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_path) + print(f"Model saved to {model_path}") + + return model + + +# 5. Testing Function: See test.py + +def run_predict(model_path, pixels_file, predictions_file, device="cpu"): + """ + Run prediction on the given pixel file using the specified model. 
+ """ + # If model file doesn't exist, train and save it + if not os.path.exists(model_path): + train_loader, val_loader, test_loader = load_and_preprocess_data(batch_size=BATCH_SIZE, data_dir=DATA_DIR) + _ = train_model(model_path, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY, + epochs=EPOCHS, train_loader=train_loader, val_loader=val_loader, + data_dir=DATA_DIR, device=device) + + # Determine saved model path + saved_model_path = model_path + if os.path.isdir(model_path): + saved_model_path = os.path.join(model_path, 'cifar10_resnet20_model.pth') if not model_path.endswith('.pth') else model_path + + test.predict(pixels_file, saved_model_path, predictions_file, device=device) + + +def export_test_pixels_labels(data_dir=DATA_DIR, pixels_file="cifar10_pixels.txt", labels_file="cifar10_labels.txt", num_samples=-1, seed=None): + """ + Export CIFAR10 test dataset to separate label and pixel files using random sampling. + + Args: + data_dir (str): Directory to download dataset temporarily. 
+ pixels_file (str): Path to the output file for pixel values + labels_file (str): Path to the output file for labels + num_samples (int): Number of samples to export (-1 for all) + """ + if seed is not None: + np.random.seed(seed) + + print("Loading CIFAR-10 test data via torchvision...") + transform = transforms.ToTensor() + test_dataset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform) + + total_samples = len(test_dataset) + samples_to_export = total_samples if num_samples == -1 else min(num_samples, total_samples) + + # Use sample_test_data to get random samples (but without normalization for export) + if samples_to_export == total_samples: + + with open(labels_file, 'w') as label_f, open(pixels_file, 'w') as pixel_f: + for image, label in test_dataset: + flattened_image = image.view(-1).numpy() + label_f.write(f"{label}\n") + pixel_values = " ".join(f"{pixel:.6f}" for pixel in flattened_image) + pixel_f.write(f"{pixel_values}\n") + else: + # Generate random indices and Create a subset dataset using the random indices + random_indices = torch.randperm(total_samples)[:samples_to_export] + subset_dataset = torch.utils.data.Subset(test_dataset, random_indices) + subset_loader = DataLoader(subset_dataset, batch_size=1, shuffle=False) + + with open(labels_file, 'w') as label_f, open(pixels_file, 'w') as pixel_f: + for batch_images, batch_labels in subset_loader: + for image, label in zip(batch_images, batch_labels): + flattened_image = image.view(-1).numpy() + label_f.write(f"{label.item()}\n") + pixel_values = " ".join(f"{pixel:.6f}" for pixel in flattened_image) + pixel_f.write(f"{pixel_values}\n") + + +def export_test_data(data_dir=DATA_DIR, output_file='cifar10_test.txt', num_samples=-1, seed=None): + """ + Export CIFAR10 test dataset to separate label and pixel files using random sampling. 
+ + Args: + data_dir (str): Directory to load dataset from + output_file (str): Base output file path (will create .labels and .pixels files) + num_samples (int): Number of samples to export (-1 for all) + """ + # Create separate file names for labels and pixels + base_name = str(output_file).rsplit('.', 1)[0] if '.' in str(output_file) else str(output_file) + labels_file = f"{base_name}_labels.txt" + pixels_file = f"{base_name}_pixels.txt" + export_test_pixels_labels(data_dir=data_dir, pixels_file=pixels_file, labels_file=labels_file, num_samples=num_samples, seed=seed) + + + +def main(argv): + # Check if we should just export test data and exit + if FLAGS.export_test_data: + print("Export mode: Loading and exporting test data...") + export_test_data(data_dir=FLAGS.data_dir, output_file=FLAGS.test_data_output, num_samples=FLAGS.num_samples) + print("Export completed. Exiting.") + return + + use_cuda = not FLAGS.no_cuda and torch.cuda.is_available() + random_seed = FLAGS.seed + # Set random seed for reproducibility + torch.manual_seed(random_seed) + + if use_cuda: + torch.cuda.manual_seed_all(random_seed) + device = "cuda" if use_cuda else "cpu" + # Train the model. + train_loader, val_loader, test_loader = load_and_preprocess_data(batch_size=FLAGS.batch_size, data_dir=FLAGS.data_dir) + model = train_model(FLAGS.model_path, FLAGS.batch_size, FLAGS.learning_rate, FLAGS.weight_decay, + FLAGS.epochs, train_loader, val_loader, data_dir=FLAGS.data_dir, device=device) + + # Check if we should run prediction and exit + if FLAGS.predict: + if not FLAGS.pixels_file: + print("Error: pixels_file must be specified when using --predict flag") + return + print("Prediction mode: Running inference on provided pixel data...") + run_predict(FLAGS.model_path, FLAGS.pixels_file, FLAGS.predictions_file, device=device) + print("Prediction completed. 
Exiting.") + return + else: + # Testing the model + print(f"\nEvaluating model on test data...") + test.test_model(model, test_loader, device) + +if __name__ == '__main__': + app.run(main) diff --git a/harness/cifar10/resnet20.py b/harness/cifar10/resnet20.py new file mode 100755 index 0000000..b98bcdf --- /dev/null +++ b/harness/cifar10/resnet20.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.shortcut_conv = None + self.shortcut_bn = None + if stride != 1 or in_channels != out_channels: + self.shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.shortcut_bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + shortcut = x + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.shortcut_conv is not None: + shortcut = self.shortcut_conv(x) + shortcut = self.shortcut_bn(shortcut) + out += shortcut + out = F.relu(out) + return out + + +class ResNet20(nn.Module): + """Compact ResNet-20 style model definition. + + Args: + channel_values (list): list of three channel widths for each stage (e.g. 
[16,32,64]) + num_classes (int): number of output classes + """ + def __init__(self, channel_values, num_classes=10): + super(ResNet20, self).__init__() + self.conv1 = nn.Conv2d(3, channel_values[0], kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(channel_values[0]) + + # Stage 1 + self.layer1_block1 = BasicBlock(channel_values[0], channel_values[0]) + self.layer1_block2 = BasicBlock(channel_values[0], channel_values[0]) + self.layer1_block3 = BasicBlock(channel_values[0], channel_values[0]) + + # Stage 2 + self.layer2_block1 = BasicBlock(channel_values[0], channel_values[1], stride=2) + self.layer2_block2 = BasicBlock(channel_values[1], channel_values[1]) + self.layer2_block3 = BasicBlock(channel_values[1], channel_values[1]) + + # Stage 3 + self.layer3_block1 = BasicBlock(channel_values[1], channel_values[2], stride=2) + self.layer3_block2 = BasicBlock(channel_values[2], channel_values[2]) + self.layer3_block3 = BasicBlock(channel_values[2], channel_values[2]) + + self.avgpool = nn.AvgPool2d(kernel_size=8) + self.fc = nn.Linear(channel_values[2], num_classes) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + + x = self.layer1_block1(x) + x = self.layer1_block2(x) + x = self.layer1_block3(x) + + x = self.layer2_block1(x) + x = self.layer2_block2(x) + x = self.layer2_block3(x) + + x = self.layer3_block1(x) + x = self.layer3_block2(x) + x = self.layer3_block3(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x \ No newline at end of file diff --git a/harness/cifar10/test.py b/harness/cifar10/test.py new file mode 100644 index 0000000..d96bdd3 --- /dev/null +++ b/harness/cifar10/test.py @@ -0,0 +1,74 @@ +import os +import torch +from resnet20 import ResNet20 +import numpy as np + + +def test_model(model, test_loader, device): + model.eval() # Set model to evaluation mode + correct = 0 + total = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = 
data.to(device), target.to(device) + output = model(data) + _, predicted = torch.max(output.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + accuracy = 100 * correct / total if total > 0 else 0.0 + print(f'Accuracy on test data: {accuracy:.2f}%') + return accuracy + + +def predict(pixels_file, model_path="cifar10_resnet20_model.pth", predictions_file='predictions.txt', device='cpu'): + """ + Load a trained ResNet20 and make predictions on pixel data from a file. + + Each line in `pixels_file` should contain 3072 float values (3*32*32), space-separated. + """ + device = torch.device(device) + + channel_values = [16, 32, 64] + num_classes = 10 + model = ResNet20(channel_values, num_classes).to(device) + + if not os.path.exists(model_path): + raise FileNotFoundError(f"Model file not found: {model_path}") + + model.load_state_dict(torch.load(model_path, map_location=device)) + model.eval() + + # CIFAR normalization + mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32) + std = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32) + + pixel_data = [] + with open(pixels_file, 'r') as f: + for line in f: + vals = [float(x) for x in line.strip().split() if x] + if len(vals) != 3 * 32 * 32: + raise ValueError(f"Each line must contain exactly 3072 pixel values, got {len(vals)}") + pixel_data.append(vals) + + if not pixel_data: + return [] + + tensors = [] + for vals in pixel_data: + arr = np.asarray(vals, dtype=np.float32).reshape((3, 32, 32)) + for c in range(3): + arr[c] = (arr[c] - mean[c]) / std[c] + tensors.append(torch.from_numpy(arr)) + + batch = torch.stack(tensors).to(device) + + with torch.no_grad(): + outputs = model(batch) + _, predicted = torch.max(outputs, 1) + predictions = predicted.cpu().numpy().tolist() + + with open(predictions_file, 'w') as out: + for p in predictions: + out.write(f"{p}\n") + + return predictions diff --git a/harness/cifar10/train.py b/harness/cifar10/train.py new file mode 100755 index 
# 0000000..e830eac --- /dev/null +++ b/harness/cifar10/train.py @@ -0,0 +1,57 @@

import torch
import tqdm
import os


def train_model_function(model, train_loader, criterion, optimizer, num_epochs=10, device='cpu'):
    """Train `model` for `num_epochs` epochs, printing loss/accuracy per epoch.

    Args:
        model: the nn.Module to train (moved to `device` here).
        train_loader: DataLoader yielding (inputs, targets) batches.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer constructed over model.parameters().
        num_epochs (int): number of training epochs.
        device (str): 'cpu' or 'cuda'.

    Returns:
        tuple[list[float], list[float]]: per-epoch average losses and
        accuracies (percentages), in epoch order.
    """
    model.to(device)  # Move the model to the specified device (GPU or CPU)

    # Per-epoch statistics, returned to the caller.
    epoch_losses = []
    epoch_accuracies = []
    # Cosine-annealed learning rate, decayed once per epoch.
    # (Fixed identifier spelling: 'schedular' -> 'scheduler'.)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # tqdm progress bar over the training batches for this epoch.
        progress_bar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch_idx, (inputs, targets) in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Accumulate loss and running accuracy statistics for the epoch.
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += targets.size(0)

        # Average loss over batches; accuracy as a percentage.
        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = correct / total * 100

        epoch_losses.append(epoch_loss)
        epoch_accuracies.append(epoch_accuracy)
        scheduler.step()  # step the LR schedule once per epoch, after updates

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

    print('Training completed!')

    return epoch_losses, epoch_accuracies

# diff --git a/harness/cleartext_impl.py b/harness/cleartext_impl.py index 289ba26..8591b95
100644 --- a/harness/cleartext_impl.py +++ b/harness/cleartext_impl.py @@ -22,20 +22,37 @@ from pathlib import Path from utils import parse_submission_arguments from mnist import mnist +from cifar10 import cifar10 def main(): """ Usage: python3 cleartext_impl.py """ - if len(sys.argv) != 3: - sys.exit("Usage: cleartext_impl.py ") + if len(sys.argv) != 4: + sys.exit("Usage: cleartext_impl.py ") + + # Paths to trained models + minst_path = "harness/mnist/mnist_ffnn_model.pth" + cifar10_path = "harness/cifar10/cifar10_resnet20_model.pth" INPUT_PATH = Path(sys.argv[1]) OUTPUT_PATH = Path(sys.argv[2]) - model_path = "harness/mnist/mnist_ffnn_model.pth" + DATASET_NAME = sys.argv[3] - mnist.run_predict(model_path=model_path, pixels_file=INPUT_PATH, predictions_file=OUTPUT_PATH) + match DATASET_NAME: + case "mnist": + return mnist.run_predict( + model_path=minst_path, + pixels_file=INPUT_PATH, + predictions_file=OUTPUT_PATH) + case "cifar10": + return cifar10.run_predict( + model_path=cifar10_path, + pixels_file=INPUT_PATH, + predictions_file=OUTPUT_PATH) + case _: + raise ValueError(f"Unsupported dataset name: {DATASET_NAME}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/harness/generate_dataset.py b/harness/generate_dataset.py index 6887eb1..b192c5e 100644 --- a/harness/generate_dataset.py +++ b/harness/generate_dataset.py @@ -19,6 +19,7 @@ import sys from pathlib import Path from mnist import mnist +from cifar10 import cifar10 def main(): """ @@ -32,11 +33,19 @@ def main(): DATASET_NAME = sys.argv[2] DATASET_PATH.parent.mkdir(parents=True, exist_ok=True) - if DATASET_NAME == "mnist": - mnist.export_test_data(output_file=DATASET_PATH, num_samples=10000, seed=None) - else: - raise ValueError(f"Unsupported dataset name: {DATASET_NAME}") - + match DATASET_NAME: + case "mnist": + return mnist.export_test_data( + output_file=DATASET_PATH, + num_samples=10000, + seed=None) + case "cifar10": + return cifar10.export_test_data( + 
output_file=DATASET_PATH, + num_samples=10000, + seed=None) + case _: + raise ValueError(f"Unsupported dataset name: {DATASET_NAME}") if __name__ == "__main__": main() diff --git a/harness/generate_input.py b/harness/generate_input.py index bd61d74..240e9fe 100644 --- a/harness/generate_input.py +++ b/harness/generate_input.py @@ -19,6 +19,7 @@ import utils from utils import parse_submission_arguments from mnist import mnist +from cifar10 import cifar10 def main(): @@ -39,6 +40,13 @@ def main(): labels_file=LABELS_PATH, num_samples=num_samples, seed=seed) + case "cifar10": + return cifar10.export_test_pixels_labels( + data_dir = params.datadir(), + pixels_file=PIXELS_PATH, + labels_file=LABELS_PATH, + num_samples=num_samples, + seed=seed) case _: raise ValueError(f"Unsupported dataset name: {dataset_name}") diff --git a/harness/mnist/mnist.py b/harness/mnist/mnist.py index 3ef0e80..b340233 100644 --- a/harness/mnist/mnist.py +++ b/harness/mnist/mnist.py @@ -24,22 +24,25 @@ RNG_SEED = 42 # for reproducibility DATA_DIR='./harness/mnist/data' -# Define command line flags -flags.DEFINE_string('model_path', MODEL_PATH, 'Path to save/load the model') -flags.DEFINE_integer('batch_size', BATCH_SIZE, 'Batch size for training and evaluation') -flags.DEFINE_float('learning_rate', LEARNING_RATE, 'Learning rate for optimizer') -flags.DEFINE_integer('epochs', EPOCHS, 'Number of training epochs') -flags.DEFINE_string('data_dir', './harness/mnist/data', 'Directory to store/load MNIST dataset') -flags.DEFINE_boolean('no_cuda', False, 'Disable CUDA even if available') -flags.DEFINE_integer('seed', RNG_SEED, 'Random seed for reproducibility') - -flags.DEFINE_boolean('export_test_data', False, 'Export test dataset to file and exit') -flags.DEFINE_string('test_data_output', 'mnist_test.txt', 'Output file for exported test data') -flags.DEFINE_integer('num_samples', -1, 'Number of samples to export (-1 for all samples)') - -flags.DEFINE_boolean('predict', False, 'Run prediction on 
pixels file and exit') -flags.DEFINE_string('pixels_file', '', 'Path to file containing pixel data for prediction') -flags.DEFINE_string('predictions_file', 'predictions.txt', 'Output file for predictions') +# Define command line flags safely to allow importing this module from other apps +try: + flags.DEFINE_string('model_path', MODEL_PATH, 'Path to save/load the model') + flags.DEFINE_integer('batch_size', BATCH_SIZE, 'Batch size for training and evaluation') + flags.DEFINE_float('learning_rate', LEARNING_RATE, 'Learning rate for optimizer') + flags.DEFINE_integer('epochs', EPOCHS, 'Number of training epochs') + flags.DEFINE_string('data_dir', './harness/mnist/data', 'Directory to store/load MNIST dataset') + flags.DEFINE_boolean('no_cuda', False, 'Disable CUDA even if available') + flags.DEFINE_integer('seed', RNG_SEED, 'Random seed for reproducibility') + + flags.DEFINE_boolean('export_test_data', False, 'Export test dataset to file and exit') + flags.DEFINE_string('test_data_output', 'mnist_test.txt', 'Output file for exported test data') + flags.DEFINE_integer('num_samples', -1, 'Number of samples to export (-1 for all samples)') + + flags.DEFINE_boolean('predict', False, 'Run prediction on pixels file and exit') + flags.DEFINE_string('pixels_file', '', 'Path to file containing pixel data for prediction') + flags.DEFINE_string('predictions_file', 'predictions.txt', 'Output file for predictions') +except flags.DuplicateFlagError: + pass # Ensure reproducibility torch.manual_seed(RNG_SEED) diff --git a/harness/utils.py b/harness/utils.py index 3a4eb35..ee38529 100644 --- a/harness/utils.py +++ b/harness/utils.py @@ -220,7 +220,7 @@ def run_exe_or_python(base, file_name, *args, check=True): exe = base / "build" / file_name if py.exists(): - cmd = ["python3", py, *args] + cmd = [sys.executable, py, *args] elif exe.exists(): cmd = [exe, *args] else: