Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions VidActRecTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from utility.eval_utility import WorstExamples
from utility.model_utility import restoreModelAndState
from utility.saliency_utils import plot_gradcam_for_multichannel_input
from utility.saliency_utils import plot_saliency_map

# ---------------------- Argument Parser ----------------------
# Added: Set up the command-line arguments as per the provided instructions.
Expand Down Expand Up @@ -787,3 +788,57 @@
width=image_size[-1],
map_percent=args.map_percent,
)

# Generate saliency maps
logging.info("Generating saliency maps...")

# Create a separate dataloader for saliency maps with smaller batch size
saliency_dataset = dataset_utility.makeDataset(
[args.evaluate], decode_strs, shuffle=False, shardshuffle=False
)
saliency_dataloader = torch.utils.data.DataLoader(
saliency_dataset, num_workers=0, batch_size=200
)

try:
# Process a few batches for saliency maps
save_folder = os.path.basename(args.evaluate).replace('.tar', '')
batch_count = 0
max_batches = 30 # Limit to avoid too many files

for batch_idx, batch in enumerate(saliency_dataloader):
if batch_count >= max_batches:
break

# Extract input tensor and labels - FIXED VERSION
if in_frames == 1:
input_tensor = batch[0].unsqueeze(1) # Add channel dimension
labels = batch[1]
else:
frames = []
for i in range(in_frames):
frame = batch[i].unsqueeze(1) # Add channel dimension
frames.append(frame)
input_tensor = torch.cat(frames, dim=1) # Concatenate along channel dim
labels = batch[in_frames]

# Generate saliency maps
plot_saliency_map(
model=net,
save_folder=save_folder,
input_tensor=input_tensor.float(),
target_class=(labels - label_offset).tolist(),
batch_num=batch_idx,
model_name=args.modeltype,
process_all_samples=True,
map_percent=args.map_percent,
power_scale=args.power_scale
)

batch_count += 1
logging.info(f"Generated saliency maps for batch {batch_count}/{max_batches}")

logging.info(f"Completed saliency map generation for {batch_count} batches")

except Exception as e:
logging.error(f"Failed to generate saliency maps: {e}")
Loading