diff --git a/VidActRecTrain.py b/VidActRecTrain.py index 90f72b6..ead4c44 100644 --- a/VidActRecTrain.py +++ b/VidActRecTrain.py @@ -52,6 +52,7 @@ from utility.eval_utility import WorstExamples from utility.model_utility import restoreModelAndState from utility.saliency_utils import plot_gradcam_for_multichannel_input +from utility.saliency_utils import plot_saliency_map # ---------------------- Argument Parser ---------------------- # Added: Set up the command-line arguments as per the provided instructions. @@ -787,3 +788,57 @@ width=image_size[-1], map_percent=args.map_percent, ) + +# Generate saliency maps +logging.info("Generating saliency maps...") + +# Create a separate dataloader for saliency maps with smaller batch size +saliency_dataset = dataset_utility.makeDataset( +[args.evaluate], decode_strs, shuffle=False, shardshuffle=False +) +saliency_dataloader = torch.utils.data.DataLoader( +saliency_dataset, num_workers=0, batch_size=200 +) + +try: + # Process a few batches for saliency maps + save_folder = os.path.basename(args.evaluate).replace('.tar', '') + batch_count = 0 + max_batches = 30 # Limit to avoid too many files + + for batch_idx, batch in enumerate(saliency_dataloader): + if batch_count >= max_batches: + break + + # Extract input tensor and labels - FIXED VERSION + if in_frames == 1: + input_tensor = batch[0].unsqueeze(1) # Add channel dimension + labels = batch[1] + else: + frames = [] + for i in range(in_frames): + frame = batch[i].unsqueeze(1) # Add channel dimension + frames.append(frame) + input_tensor = torch.cat(frames, dim=1) # Concatenate along channel dim + labels = batch[in_frames] + + # Generate saliency maps + plot_saliency_map( + model=net, + save_folder=save_folder, + input_tensor=input_tensor.float(), + target_class=(labels - label_offset).tolist(), + batch_num=batch_idx, + model_name=args.modeltype, + process_all_samples=True, + map_percent=args.map_percent, + power_scale=args.power_scale + ) + + batch_count += 1 + logging.info(f"Generated saliency maps for batch {batch_count}/{max_batches}") + + logging.info(f"Completed saliency map generation for {batch_count} batches") + +except Exception as e: + logging.error(f"Failed to generate saliency maps: {e}") \ No newline at end of file