diff --git a/multic/segmentationschool/Codes/IterativeTraining_1X.py b/multic/segmentationschool/Codes/IterativeTraining_1X.py index 9c98890..d7fb5fb 100644 --- a/multic/segmentationschool/Codes/IterativeTraining_1X.py +++ b/multic/segmentationschool/Codes/IterativeTraining_1X.py @@ -1,677 +1,405 @@ -import os, sys, cv2, time, random, warnings, multiprocessing#json,# detectron2 +import os,cv2, time, random, multiprocessing,copy +from skimage.color import rgb2hsv,hsv2rgb,rgb2lab,lab2rgb import numpy as np -import matplotlib.pyplot as plt -import lxml.etree as ET -from matplotlib import path -from skimage.transform import resize -from skimage.io import imread, imsave -import glob -from .getWsi import getWsi - -from .xml_to_mask2 import get_supervision_boxes, regions_in_mask_dots, get_vertex_points_dots, masks_from_points, restart_line -from joblib import Parallel, delayed -from shutil import move +from tiffslide import TiffSlide +from .xml_to_mask_minmax import xml_to_mask # from generateTrainSet import generateDatalists -#from subprocess import call -#from .get_choppable_regions import get_choppable_regions -from PIL import Image - +import logging from detectron2.utils.logger import setup_logger +from skimage import exposure + setup_logger() from detectron2 import model_zoo -from detectron2.engine import DefaultPredictor,DefaultTrainer +from detectron2.engine import DefaultTrainer from detectron2.config import get_cfg -from detectron2.utils.visualizer import Visualizer#,ColorMode -from detectron2.data import MetadataCatalog, DatasetCatalog -#from detectron2.structures import BoxMode -from .get_dataset_list import HAIL2Detectron, samples_from_json, samples_from_json_mini -#from detectron2.checkpoint import DetectionCheckpointer -#from detectron2.modeling import build_model - -""" - -Code for - cutting / augmenting / training CNN - -This uses WSI and XML files to train 2 neural networks for semantic segmentation - of histopath tissue via human in the loop training - -""" - +# from detectron2.data import MetadataCatalog, DatasetCatalog +from detectron2.data import detection_utils as utils +import detectron2.data.transforms as T +from detectron2.structures import BoxMode +from detectron2.data import (DatasetCatalog, + MetadataCatalog, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.config import configurable +from typing import List, Optional, Union +import torch +from detectron2.evaluation import COCOEvaluator +#from .engine.hooks import LossEvalHook +# sys.append("..") +from .wsi_loader_utils import train_samples_from_WSI, get_slide_data, get_random_chops +from imgaug import augmenters as iaa +from .engine.hooks import LossEvalHook + +global seq +seq = iaa.Sequential([ + iaa.Sometimes(0.5,iaa.OneOf([ + iaa.AddElementwise((-15,15),per_channel=0.5), + iaa.ImpulseNoise(0.05),iaa.CoarseDropout(0.02, size_percent=0.5)])), + iaa.Sometimes(0.5,iaa.OneOf([iaa.GaussianBlur(sigma=(0, 3.0)), + iaa.Sharpen(alpha=(0.0, 1.0), lightness=(0.75, 2.0))])) +]) #Record start time totalStart=time.time() def IterateTraining(args): - ## calculate low resolution block params - downsampleLR = int(args.downsampleRateLR**.5) #down sample for each dimension - region_sizeLR = int(args.boxSizeLR*(downsampleLR)) #Region size before downsampling - stepLR = int(region_sizeLR*(1-args.overlap_percentLR)) #Step size before downsampling - ## calculate low resolution block params - downsampleHR = int(args.downsampleRateHR**.5) #down sample for each dimension - region_sizeHR = 
int(args.boxSizeHR*(downsampleHR)) #Region size before downsampling - stepHR = int(region_sizeHR*(1-args.overlap_percentHR)) #Step size before downsampling - - global classNum_HR,classEnumLR,classEnumHR + + region_size = int(args.boxSize) #Region size before downsampling + dirs = {'imExt': '.jpeg'} dirs['basedir'] = args.base_dir dirs['maskExt'] = '.png' - dirs['modeldir'] = '/MODELS/' - dirs['tempdirLR'] = '/TempLR/' - dirs['tempdirHR'] = '/TempHR/' - dirs['pretraindir'] = '/Deeplab_network/' - dirs['training_data_dir'] = '/TRAINING_data/' - dirs['model_init'] = 'deeplab_resnet.ckpt' - dirs['project']= '/' + args.project - dirs['data_dir_HR'] = args.base_dir +'/' + args.project + '/Permanent/HR/' - dirs['data_dir_LR'] = args.base_dir +'/' +args.project + '/Permanent/LR/' - - - ##All folders created, initiate WSI loading by human - #raw_input('Please place WSIs in ') - - ##Check iteration session - - currentmodels=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir']) - print('Handcoded iteration') - # currentAnnotationIteration=check_model_generation(dirs) - currentAnnotationIteration=2 - print('Current training session is: ' + str(currentAnnotationIteration)) - - ##Create objects for storing class distributions - annotatedXMLs=glob.glob(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration) + '/*.xml') - classes=[] + dirs['training_data_dir'] = args.training_data_dir + dirs['val_data_dir'] = args.training_data_dir - if args.classNum == 0: - for xml in annotatedXMLs: - classes.append(get_num_classes(xml)) - classNum_HR = max(classes) - else: - classNum_LR = args.classNum - if args.classNum_HR != 0: - classNum_HR = args.classNum_HR - else: - classNum_HR = classNum_LR - - classNum_HR=args.classNum - - ##for all WSIs in the initiating directory: - if args.chop_data == 'True': - print('Chopping') - - start=time.time() - size_data=[] - - for xmlID in annotatedXMLs: - - #Get unique name of WSI - fileID=xmlID.split('/')[-1].split('.xml')[0] - print('-----------------'+fileID+'----------------') - #create memory addresses for wsi files - for ext in [args.wsi_ext]: - wsiID=dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration) +'/'+ fileID + ext - - #Ensure annotations exist - if os.path.isfile(wsiID)==True: - break - #Load openslide information about WSI - if ext != '.tif': - slide=getWsi(wsiID) - #WSI level 0 dimensions (largest size) - dim_x,dim_y=slide.dimensions - else: - im = Image.open(wsiID) - dim_x, dim_y=im.size - location=[0,0] - size=[dim_x,dim_y] - tree = ET.parse(xmlID) - root = tree.getroot() - box_supervision_layers=['8'] - # calculate region bounds - global_bounds = {'x_min' : location[0], 'y_min' : location[1], 'x_max' : location[0] + size[0], 'y_max' : location[1] + size[1]} - local_bounds = get_supervision_boxes(root,box_supervision_layers) - num_cores = multiprocessing.cpu_count() - Parallel(n_jobs=num_cores)(delayed(chop_suey_bounds)(args=args,wsiID=wsiID, - dirs=dirs,lb=lb,xmlID=xmlID,box_supervision_layers=box_supervision_layers) for lb in tqdm(local_bounds)) - # for lb in tqdm(local_bounds): - - # size_data.extend(image_sizes) + print('Handcoded iteration') - ''' - wsi_mask=xml_to_mask(xmlID, [0,0], [dim_x,dim_y]) - #Enumerate cpu core count - num_cores = multiprocessing.cpu_count() + #os.environ["CUDA_VISIBLE_DEVICES"]=gpu + #os.system('export CUDA_VISIBLE_DEVICES=$(nvidia-smi --query-gpu=memory.free,index --format=csv,nounits,noheader | sort -nr | head -1 | awk "{ print $NF }")') + 
os.environ["CUDA_VISIBLE_DEVICES"] ='2,3' + os.environ["CUDA_LAUNCH_BLOCKING"] ='1' - #Generate iterators for parallel chopping of WSIs in high resolution - #index_yHR=range(30240,dim_y-stepHR,stepHR) - #index_xHR=range(840,dim_x-stepHR,stepHR) - index_yHR=range(0,dim_y,stepHR) - index_xHR=range(0,dim_x,stepHR) - index_yHR[-1]=dim_y-stepHR - index_xHR[-1]=dim_x-stepHR - #Create memory address for chopped images high resolution - outdirHR=dirs['basedir'] + dirs['project'] + dirs['tempdirHR'] - #Perform high resolution chopping in parallel and return the number of - #images in each of the labeled classes - chop_regions=get_choppable_regions(wsi=wsiID, - index_x=index_xHR,index_y=index_yHR,boxSize=region_sizeHR,white_percent=args.white_percent) + organType='kidney' + print('Organ meta being set to... '+ organType) + if organType=='liver': + classnames=['Background','BD','A'] + isthing=[0,1,1] + xml_color = [[0,255,0], [0,255,255], [0,0,255]] + tc=['BD','AT'] + sc=['Ob','B'] + elif organType =='kidney': + classnames=['interstitium','medulla','glomerulus','sclerotic glomerulus','tubule','arterial tree'] + classes={} + isthing=[0,0,1,1,1,1] + xml_color = [[0,255,0], [0,255,255], [255,255,0],[0,0,255], [255,0,0], [0,128,255]] + tc=['G','SG','T','A'] + sc=['Ob','I','M','B'] + else: + print('Provided organType not in supported types: kidney, liver') - Parallel(n_jobs=num_cores)(delayed(return_region)(args=args, - wsi_mask=wsi_mask, wsiID=wsiID, - fileID=fileID, yStart=j, xStart=i, idxy=idxy, - idxx=idxx, downsampleRate=args.downsampleRateHR, - outdirT=outdirHR, region_size=region_sizeHR, - dirs=dirs, chop_regions=chop_regions,classNum_HR=classNum_HR) for idxx,i in enumerate(index_xHR) for idxy,j in enumerate(index_yHR)) - #for idxx,i in enumerate(index_xHR): - # for idxy,j in enumerate(index_yHR): - # if chop_regions[idxy,idxx] != 0: - # return_region(args=args,xmlID=xmlID, wsiID=wsiID, fileID=fileID, yStart=j, xStart=i,idxy=idxy, idxx=idxx, - # downsampleRate=args.downsampleRateHR,outdirT=outdirHR, region_size=region_sizeHR, dirs=dirs, - # chop_regions=chop_regions,classNum_HR=classNum_HR) - # else: - # print('pass') - # exit() - print('Time for WSI chopping: ' + str(time.time()-start)) + classNum=len(tc)+len(sc)-1 + print('Number classes: '+ str(classNum)) + classes={} - classEnumHR=np.ones([classNum_HR,1])*classNum_HR + for idx,c in enumerate(classnames): + classes[idx]={'isthing':isthing[idx],'color':xml_color[idx]} - ##High resolution augmentation - #Enumerate high resolution class distribution - classDistHR=np.zeros(len(classEnumHR)) - for idx,value in enumerate(classEnumHR): - classDistHR[idx]=value/sum(classEnumHR) - print(classDistHR) - #Define number of augmentations per class - moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions/', dirs['basedir']+dirs['project'] + '/Permanent/HR/regions/') - moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/masks/',dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/') + num_images=args.batch_size*args.train_steps + # slide_idxs=train_dset.get_random_slide_idx(num_images) + usable_slides=get_slide_data(args, wsi_directory = dirs['training_data_dir']) + print('Number of slides:', len(usable_slides)) + usable_idx=range(0,len(usable_slides)) + slide_idxs=random.choices(usable_idx,k=num_images) + image_coordinates=get_random_chops(slide_idxs,usable_slides,region_size) + DatasetCatalog.register("my_dataset", lambda:train_samples_from_WSI(args,image_coordinates)) + 
MetadataCatalog.get("my_dataset").set(thing_classes=tc) + MetadataCatalog.get("my_dataset").set(stuff_classes=sc) - #Total time - print('Time for high resolution augmenting: ' + str((time.time()-totalStart)/60) + ' minutes.') - ''' - - # with open('sizes.csv','w',newline='') as myfile: - # wr=csv.writer(myfile,quoting=csv.QUOTE_ALL) - # wr.writerow(size_data) - # pretrain_HR=get_pretrain(currentAnnotationIteration,'/HR/',dirs) - - modeldir_HR = dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration+1) + '/HR/' - - - ##### HIGH REZ ARGS ##### - dirs['outDirAIHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/regions/' - dirs['outDirAMHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/masks/' - - - numImagesHR=len(glob.glob(dirs['outDirAIHR'] + '*' + dirs['imExt'])) - - numStepsHR=(args.epoch_HR*numImagesHR)/ args.CNNbatch_sizeHR - - - #----------------------------------------------------------------------------------------- - # os.environ["CUDA_VISIBLE_DEVICES"]='0' - os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) - # img_dir='/hdd/bg/Detectron2/chop_detectron/Permanent/HR' + usable_slides_val=get_slide_data(args, wsi_directory=dirs['val_data_dir']) - img_dir=dirs['outDirAIHR'] - classnames=['Background','BD','A'] - isthing=[0,1,1] - xml_color = [[0,255,0], [0,255,255], [0,0,255]] + usable_idx_val=range(0,len(usable_slides_val)) + slide_idxs_val=random.choices(usable_idx_val,k=int(args.batch_size*args.train_steps/100)) + image_coordinates_val=get_random_chops(slide_idxs_val,usable_slides_val,region_size) - rand_sample=True - json_file=img_dir+'/detectron_train.json' - HAIL2Detectron(img_dir,rand_sample,json_file,classnames,isthing,xml_color) - tc=['BD','AT'] - sc=['I','B'] - #### From json - DatasetCatalog.register("my_dataset", lambda:samples_from_json(json_file,rand_sample)) - MetadataCatalog.get("my_dataset").set(thing_classes=tc) - MetadataCatalog.get("my_dataset").set(stuff_classes=sc) + - seg_metadata=MetadataCatalog.get("my_dataset") - - - # new_list = DatasetCatalog.get("my_dataset") - # print(len(new_list)) - # for d in random.sample(new_list, 100): - # - # img = cv2.imread(d["file_name"]) - # visualizer = Visualizer(img[:, :, ::-1],metadata=seg_metadata, scale=0.5) - # out = visualizer.draw_dataset_dict(d) - # cv2.namedWindow("output", cv2.WINDOW_NORMAL) - # cv2.imshow("output",out.get_image()[:, :, ::-1]) - # cv2.waitKey(0) # waits until a key is pressed - # cv2.destroyAllWindows() - # exit() + _ = os.system("printf '\nTraining starts...\n'") cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml")) cfg.DATASETS.TRAIN = ("my_dataset") - cfg.DATASETS.TEST = () + + cfg.TEST.EVAL_PERIOD=args.eval_period + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50 num_cores = multiprocessing.cpu_count() - cfg.DATALOADER.NUM_WORKERS = num_cores-3 - # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo + cfg.DATALOADER.NUM_WORKERS = args.num_workers - cfg.MODEL.WEIGHTS = os.path.join('/hdd/bg/Detectron2/HAIL_Detectron2/liver/MODELS/0/HR', "model_final.pth") + if args.init_modelfile: + cfg.MODEL.WEIGHTS = args.init_modelfile + else: + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo - cfg.SOLVER.IMS_PER_BATCH = 10 + cfg.SOLVER.IMS_PER_BATCH = args.batch_size - # cfg.SOLVER.BASE_LR = 0.02 # pick a good LR - # 
cfg.SOLVER.LR_policy='steps_with_lrs' - # cfg.SOLVER.MAX_ITER = 50000 - # cfg.SOLVER.STEPS = [30000,40000] - # # cfg.SOLVER.STEPS = [] - # cfg.SOLVER.LRS = [0.002,0.0002] - cfg.SOLVER.BASE_LR = 0.002 # pick a good LR cfg.SOLVER.LR_policy='steps_with_lrs' - cfg.SOLVER.MAX_ITER = 200000 - cfg.SOLVER.STEPS = [150000,180000] - # cfg.SOLVER.STEPS = [] - cfg.SOLVER.LRS = [0.0002,0.00002] - - # cfg.INPUT.CROP.ENABLED = True - # cfg.INPUT.CROP.TYPE='absolute' - # cfg.INPUT.CROP.SIZE=[100,100] - cfg.MODEL.BACKBONE.FREEZE_AT = 0 - # cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[4],[8],[16], [32], [64], [64], [64]] - # cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p2', 'p2', 'p3','p4','p5','p6'] - cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.33, 0.5, 1.0, 2.0, 3.0]] + cfg.SOLVER.MAX_ITER = args.train_steps + cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR + cfg.SOLVER.LRS = [0.000025,0.0000025] + cfg.SOLVER.STEPS = [70000,90000] + cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32],[64],[128], [256], [512], [1024]] + cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p3', 'p4', 'p5','p6','p6'] + cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[.1,.2,0.33, 0.5, 1.0, 2.0, 3.0,5,10]] cfg.MODEL.ANCHOR_GENERATOR.ANGLES=[-90,-60,-30,0,30,60,90] - - cfg.MODEL.RPN.POSITIVE_FRACTION = 0.75 - cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(tc) cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES =len(sc) - - - # cfg.INPUT.CROP.ENABLED = True - # cfg.INPUT.CROP.TYPE='absolute' - # cfg.INPUT.CROP.SIZE=[64,64] - - cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256 # faster, and good enough for this toy dataset (default: 512) - # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4 # only has one class (ballon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets) - + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64 cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS=False - cfg.INPUT.MIN_SIZE_TRAIN=0 - # cfg.INPUT.MAX_SIZE_TRAIN=500 - - # exit() - os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) - with open(cfg.OUTPUT_DIR+"/config_record.yaml", "w") as f: - f.write(cfg.dump()) # save config to file - trainer = DefaultTrainer(cfg) - trainer.resume_or_load(resume=False) - trainer.train() - - cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.01 # set a custom testing threshold - cfg.TEST.DETECTIONS_PER_IMAGE = 500 - # - # cfg.INPUT.MIN_SIZE_TRAIN=64 - # cfg.INPUT.MAX_SIZE_TRAIN=4000 - cfg.INPUT.MIN_SIZE_TEST=64 - cfg.INPUT.MAX_SIZE_TEST=500 - - - predict_samples=100 - predictor = DefaultPredictor(cfg) - - dataset_dicts = samples_from_json_mini(json_file,predict_samples) - iter=0 - if not os.path.exists(os.getcwd()+'/network_predictions/'): - os.mkdir(os.getcwd()+'/network_predictions/') - for d in random.sample(dataset_dicts, predict_samples): - # print(d["file_name"]) - # imclass=d["file_name"].split('/')[-1].split('_')[-5].split(' ')[-1] - # if imclass in ["TRI","HE"]: - im = cv2.imread(d["file_name"]) - panoptic_seg, segments_info = predictor(im)["panoptic_seg"] # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format - # print(segments_info) - # plt.imshow(panoptic_seg.to("cpu")) - # plt.show() - v = Visualizer(im[:, :, ::-1], seg_metadata, scale=1.2) - v = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info) - # panoptic segmentation result - # plt.ion() - plt.subplot(121) - plt.imshow(im[:, :, ::-1]) - plt.subplot(122) - plt.imshow(v.get_image()) - plt.savefig(f"./network_predictions/input_{iter}.jpg",dpi=300) - plt.show() 
- # plt.ioff() - - - # v = Visualizer(im[:, :, ::-1], - # metadata=seg_metadata, - # scale=0.5, - # ) - # out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"),segments_info) - - # imsave('./network_predictions/pred'+str(iter)+'.png',np.hstack((im,v.get_image()))) - iter=iter+1 - # cv2.imshow('',out.get_image()[:, :, ::-1]) - # cv2.waitKey(0) # waits until a key is pressed - # cv2.destroyAllWindows() - #----------------------------------------------------------------------------------------- - - finish_model_generation(dirs,currentAnnotationIteration) - - print('\n\n\033[92;5mPlease place new wsi file(s) in: \n\t' + dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration+1)) - print('\nthen run [--option predict]\033[0m\n') - - - - -def moveimages(startfolder,endfolder): - filelist=glob.glob(startfolder + '*') - for file in filelist: - fileID=file.split('/')[-1] - move(file,endfolder + fileID) - - -def check_model_generation(dirs): - modelsCurrent=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir']) - gens=map(int,modelsCurrent) - modelOrder=np.sort(gens)[::-1] - - for idx in modelOrder: - #modelsChkptsLR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+str(modelsCurrent[idx]) + '/LR/*.ckpt*') - modelsChkptsHR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+ str(idx) +'/HR/*.ckpt*') - if modelsChkptsHR == []: - continue - else: - return idx - break - -def finish_model_generation(dirs,currentAnnotationIteration): - make_folder(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration + 1)) + cfg.INPUT.MIN_SIZE_TRAIN=args.boxSize + cfg.INPUT.MAX_SIZE_TRAIN=args.boxSize + + cfg.OUTPUT_DIR = args.base_dir+"/output" -def get_pretrain(currentAnnotationIteration,res,dirs): - - if currentAnnotationIteration==0: - pretrain_file = glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + '*') - pretrain_file=pretrain_file[0].split('.')[0] + '.' 
+ pretrain_file[0].split('.')[1] - - else: - pretrains=glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model*') - maxmodel=0 - for modelfiles in pretrains: - modelID=modelfiles.split('.')[-2].split('-')[1] - if int(modelID)>maxmodel: - maxmodel=int(modelID) - pretrain_file=dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model.ckpt-' + str(maxmodel) - return pretrain_file -def restart_line(): # for printing chopped image labels in command line - sys.stdout.write('\r') - sys.stdout.flush() + def real_data(args,image_coordinates_val): -def file_len(fname): # get txt file length (number of lines) - with open(fname) as f: - for i, l in enumerate(f): - pass - return i + 1 -def make_folder(directory): - if not os.path.exists(directory): - os.makedirs(directory) # make directory if it does not exit already # make new directory # Check if folder exists, if not make it + all_list=[] + for one in train_samples_from_WSI(args,image_coordinates_val): + dataset_dict = one + c=dataset_dict['coordinates'] + h=dataset_dict['height'] + w=dataset_dict['width'] + maskData=xml_to_mask(dataset_dict['xml_loc'], c, [h,w]) + dataset_dict['annotations'] = mask2polygons(maskData) + all_list.append(dataset_dict) -def make_all_folders(dirs): + return all_list + DatasetCatalog.register("my_dataset_val", lambda:real_data(args,image_coordinates_val)) + MetadataCatalog.get("my_dataset_val").set(thing_classes=tc) + MetadataCatalog.get("my_dataset_val").set(stuff_classes=sc) + + cfg.DATASETS.TEST = ("my_dataset_val",) - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/masks') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/masks') - - make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirHR'] + '/masks') - - make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/Augment' +'/regions') - make_folder(dirs['basedir']+dirs['project']+ dirs['tempdirHR'] + '/Augment' +'/masks') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['modeldir']) - make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir']) - - - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'regions/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'masks/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'regions/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'masks/') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir']) - - make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetLR') - make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetHR') + + + #os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) + with open(cfg.OUTPUT_DIR+"/config_record.yaml", "w") as f: + f.write(cfg.dump()) # save config to file -def return_region(args, wsi_mask, wsiID, fileID, yStart, xStart, idxy, idxx, downsampleRate, outdirT, region_size, dirs, chop_regions,classNum_HR): # perform cutting in parallel - sys.stdout.write(' <'+str(xStart)+'/'+ str(yStart)+'/'+str(chop_regions[idxy,idxx] != 0)+ '> ') - sys.stdout.flush() - restart_line() - if chop_regions[idxy,idxx] != 0: + trainer = Trainer(cfg) + print('check and see') + 
trainer.resume_or_load(resume=False) + trainer.train() - uniqID=fileID + str(yStart) + str(xStart) - if wsiID.split('.')[-1] != 'tif': - slide=getWsi(wsiID) - Im=np.array(slide.read_region((xStart,yStart),0,(region_size,region_size))) - Im=Im[:,:,:3] - else: - yEnd = yStart + region_size - xEnd = xStart + region_size - Im = np.zeros([region_size,region_size,3], dtype=np.uint8) - Im_ = imread(wsiID)[yStart:yEnd, xStart:xEnd, :3] - Im[0:Im_.shape[0], 0:Im_.shape[1], :] = Im_ - - mask_annotation=wsi_mask[yStart:yStart+region_size,xStart:xStart+region_size] - - o1,o2=mask_annotation.shape - if o1 !=region_size: - mask_annotation=np.pad(mask_annotation,((0,region_size-o1),(0,0)),mode='constant') - if o2 !=region_size: - mask_annotation=np.pad(mask_annotation,((0,0),(0,region_size-o2)),mode='constant') - - ''' - if 4 in np.unique(mask_annotation): - plt.subplot(121) - plt.imshow(mask_annotation*20) - plt.subplot(122) - plt.imshow(Im) - pt=[xStart,yStart] - plt.title(pt) - plt.show() - ''' - if downsampleRate !=1: - c=(Im.shape) - s1=int(c[0]/(downsampleRate**.5)) - s2=int(c[1]/(downsampleRate**.5)) - Im=resize(Im,(s1,s2),mode='reflect') - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - imsave(outdirT + '/regions/' + uniqID + dirs['imExt'],Im) - imsave(outdirT + '/masks/' + uniqID +dirs['maskExt'],mask_annotation) - - -def regions_in_mask(root, bounds, verbose=1): - # find regions to save - IDs_reg = [] - IDs_points = [] - - for Annotation in root.findall("./Annotation"): # for all annotations - annotationID = Annotation.attrib['Id'] - annotationType = Annotation.attrib['Type'] - - # print(Annotation.findall(./)) - if annotationType =='9': - for element in Annotation.iter('InputAnnotationId'): - pointAnnotationID=element.text - - for Region in Annotation.findall("./*/Region"): # iterate on all region - - for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region - # get points - x_point = np.int32(np.float64(Vertex.attrib['X'])) - y_point = np.int32(np.float64(Vertex.attrib['Y'])) - # test if points are in bounds - if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds - # save region Id - IDs_points.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID,'pointAnnotationID':pointAnnotationID}) - break - elif annotationType=='4': - - for Region in Annotation.findall("./*/Region"): # iterate on all region - - for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region - # get points - x_point = np.int32(np.float64(Vertex.attrib['X'])) - y_point = np.int32(np.float64(Vertex.attrib['Y'])) - # test if points are in bounds - if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds - # save region Id - IDs_reg.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID}) - break - return IDs_reg,IDs_points - - -def get_vertex_points(root, IDs_reg,IDs_points, maskModes,excludedIDs,negativeIDs=None): - Regions = [] - Points = [] - - for ID in IDs_reg: - Vertices = [] - if ID['annotationID'] not in excludedIDs: - for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"): - Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))]) - Regions.append({'Vertices':np.array(Vertices),'annotationID':ID['annotationID']}) - - for ID in IDs_points: - Vertices = [] - for Vertex in 
root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"): - Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))]) - Points.append({'Vertices':np.array(Vertices),'pointAnnotationID':ID['pointAnnotationID']}) - if 'falsepositive' or 'negative' in maskModes: - assert negativeIDs is not None,'Negatively annotated classes must be provided for negative/falsepositive mask mode' - assert 'falsepositive' and 'negative' not in maskModes, 'Negative and false positive mask modes cannot both be true' - - useableRegions=[] - if 'positive' in maskModes: - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - for Point in Points: - if Region['annotationID'] not in negativeIDs: - if regionPath.contains_point(Point['Vertices'][0]): - Region['pointAnnotationID']=Point['pointAnnotationID'] - useableRegions.append(Region) - - if 'negative' in maskModes: - - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - if Region['annotationID'] in negativeIDs: - if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]): - Region['pointAnnotationID']=Region['annotationID'] - useableRegions.append(Region) - if 'falsepositive' in maskModes: - - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - if Region['annotationID'] in negativeIDs: - if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]): - Region['pointAnnotationID']=0 - useableRegions.append(Region) - - return useableRegions -def chop_suey_bounds(lb,xmlID,box_supervision_layers,wsiID,dirs,args): - tree = ET.parse(xmlID) - root = tree.getroot() - lbVerts=np.array(lb['BoxVerts']) - xMin=min(lbVerts[:,0]) - xMax=max(lbVerts[:,0]) - yMin=min(lbVerts[:,1]) - yMax=max(lbVerts[:,1]) - - # test=np.array(slide.read_region((xMin,yMin),0,(xMax-xMin,yMax-yMin)))[:,:,:3] - - local_bound = {'x_min' : xMin, 'y_min' : yMin, 'x_max' : xMax, 'y_max' : yMax} - IDs_reg,IDs_points = regions_in_mask_dots(root=root, bounds=local_bound,box_layers=box_supervision_layers) - - # find regions in bounds - negativeIDs=['4'] - excludedIDs=['1'] - falsepositiveIDs=['4'] - usableRegions= get_vertex_points_dots(root=root, IDs_reg=IDs_reg,IDs_points=IDs_points,excludedIDs=excludedIDs,maskModes=['falsepositive','positive'],negativeIDs=negativeIDs, - falsepositiveIDs=falsepositiveIDs) - - # image_sizes= - masks_from_points(usableRegions,wsiID,dirs,50,args,[xMin,xMax,yMin,yMax]) -''' -def masks_from_points(root,usableRegions,wsiID,dirs): - pas_img = getWsi(wsiID) - image_sizes=[] - basename=wsiID.split('/')[-1].split('.svs')[0] - - for usableRegion in tqdm(usableRegions): - vertices=usableRegion['Vertices'] - x1=min(vertices[:,0]) - x2=max(vertices[:,0]) - y1=min(vertices[:,1]) - y2=max(vertices[:,1]) - points = np.stack([np.asarray(vertices[:,0]), np.asarray(vertices[:,1])], axis=1) - if (x2-x1)>0 and (y2-y1)>0: - l1=x2-x1 - l2=y2-y1 - xMultiplier=np.ceil((l1)/64) - yMultiplier=np.ceil((l2)/64) - pad1=int(xMultiplier*64-l1) - pad2=int(yMultiplier*64-l2) - - points[:,1] = np.int32(np.round(points[:,1] - y1 )) - points[:,0] = np.int32(np.round(points[:,0] - x1 )) - mask = 2*np.ones([y2-y1,x2-x1], dtype=np.uint8) - if int(usableRegion['pointAnnotationID'])==0: - pass - else: - cv2.fillPoly(mask, [points], int(usableRegion['pointAnnotationID'])-4) - PAS = pas_img.read_region((x1,y1), 0, (x2-x1,y2-y1)) - # print(usableRegion['pointAnnotationID']) - PAS = np.array(PAS)[:,:,0:3] - mask=np.pad( 
mask,((0,pad2),(0,pad1)),'constant',constant_values=(2,2) ) - PAS=np.pad( PAS,((0,pad2),(0,pad1),(0,0)),'constant',constant_values=(0,0) ) - - image_identifier=basename+'_'.join(['',str(x1),str(y1),str(l1),str(l2)]) - mask_out_name=dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/'+image_identifier+'.png' - image_out_name=mask_out_name.replace('/masks/','/regions/') - # basename + '_' + str(image_identifier) + args.imBoxExt - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - imsave(image_out_name,PAS) - imsave(mask_out_name,mask) - # exit() - # extract image region - # plt.subplot(121) - # plt.imshow(PAS) - # plt.subplot(122) - # plt.imshow(mask) - # plt.show() - # image_sizes.append([x2-x1,y2-y1]) + _ = os.system("printf '\nTraining completed!\n'") + +def mask2polygons(mask): + annotation=[] + presentclasses=np.unique(mask) + offset=-3 + presentclasses=presentclasses[presentclasses>2] + presentclasses=list(presentclasses[presentclasses<7]) + for p in presentclasses: + contours, hierarchy = cv2.findContours(np.array(mask==p, dtype='uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + for contour in contours: + if contour.size>=6: + instance_dict={} + contour_flat=contour.flatten().astype('float').tolist() + xMin=min(contour_flat[::2]) + yMin=min(contour_flat[1::2]) + xMax=max(contour_flat[::2]) + yMax=max(contour_flat[1::2]) + instance_dict['bbox']=[xMin,yMin,xMax,yMax] + instance_dict['bbox_mode']=BoxMode.XYXY_ABS.value + instance_dict['category_id']=p+offset + instance_dict['segmentation']=[contour_flat] + annotation.append(instance_dict) + return annotation + + +class Trainer(DefaultTrainer): + + @classmethod + def build_test_loader(cls, cfg, dataset_name): + return build_detection_test_loader(cfg, dataset_name, mapper=CustomDatasetMapper(cfg, True)) + + @classmethod + def build_train_loader(cls, cfg): + return build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True)) + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + return COCOEvaluator(dataset_name, cfg, True, output_folder) + + def build_hooks(self): + hooks = super().build_hooks() + hooks.insert(-1,LossEvalHook( + self.cfg.TEST.EVAL_PERIOD, + self.model, + build_detection_test_loader( + self.cfg, + self.cfg.DATASETS.TEST[0], + CustomDatasetMapper(self.cfg, True) + ) + )) + return hooks + + +class CustomDatasetMapper: + + @configurable + def __init__( + self, + is_train: bool, + *, + augmentations: List[Union[T.Augmentation, T.Transform]], + image_format: str, + use_instance_mask: bool = False, + use_keypoint: bool = False, + instance_mask_format: str = "polygon", + keypoint_hflip_indices: Optional[np.ndarray] = None, + precomputed_proposal_topk: Optional[int] = None, + recompute_boxes: bool = False, + ): + """ + NOTE: this interface is experimental. + + Args: + is_train: whether it's used in training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + use_instance_mask: whether to process instance segmentation annotations, if available + use_keypoint: whether to process keypoint annotations if available + instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation + masks into this format. 
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` + precomputed_proposal_topk: if given, will load pre-computed + proposals from dataset_dict and keep the top k proposals for each image. + recompute_boxes: whether to overwrite bounding box annotations + by computing tight bounding boxes from instance mask annotations. + """ + if recompute_boxes: + assert use_instance_mask, "recompute_boxes requires instance masks" + # fmt: off + self.is_train = is_train + self.augmentations = T.AugmentationList(augmentations) + self.image_format = image_format + self.use_instance_mask = use_instance_mask + self.instance_mask_format = instance_mask_format + self.use_keypoint = use_keypoint + self.keypoint_hflip_indices = keypoint_hflip_indices + self.proposal_topk = precomputed_proposal_topk + self.recompute_boxes = recompute_boxes + # fmt: on + logger = logging.getLogger(__name__) + mode = "training" if is_train else "inference" + logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") + + @classmethod + def from_config(cls, cfg, is_train: bool = True): + augs = utils.build_augmentation(cfg, is_train) + if cfg.INPUT.CROP.ENABLED and is_train: + augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) + recompute_boxes = cfg.MODEL.MASK_ON else: - print('Broken region') - return image_sizes -''' + recompute_boxes = False + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "use_instance_mask": cfg.MODEL.MASK_ON, + "instance_mask_format": cfg.INPUT.MASK_FORMAT, + "use_keypoint": cfg.MODEL.KEYPOINT_ON, + "recompute_boxes": recompute_boxes, + } + + if cfg.MODEL.KEYPOINT_ON: + ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) + + if cfg.MODEL.LOAD_PROPOSALS: + ret["precomputed_proposal_topk"] = ( + cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN + if is_train + else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST + ) + return ret + + def _transform_annotations(self, dataset_dict, transforms, image_shape): + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.use_instance_mask: + anno.pop("segmentation", None) + if not self.use_keypoint: + anno.pop("keypoints", None) + + annos = [ + utils.transform_instance_annotations( + obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices + ) + for obj in dataset_dict.pop('annotations') + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.instance_mask_format + ) + + if self.recompute_boxes: + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict["instances"] = utils.filter_empty_instances(instances) + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
+ + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + c=dataset_dict['coordinates'] + h=dataset_dict['height'] + w=dataset_dict['width'] + + slide= TiffSlide(dataset_dict['slide_loc']) + image=np.array(slide.read_region((c[0],c[1]),0,(h,w)))[:,:,:3] + slide.close() + maskData=xml_to_mask(dataset_dict['xml_loc'], c, [h,w]) + + if random.random()>0.5: + hShift=np.random.normal(0,0.05) + lShift=np.random.normal(1,0.025) + # imageblock[im]=randomHSVshift(imageblock[im],hShift,lShift) + image=rgb2hsv(image) + image[:,:,0]=(image[:,:,0]+hShift) + image=hsv2rgb(image) + image=rgb2lab(image) + image[:,:,0]=exposure.adjust_gamma(image[:,:,0],lShift) + image=(lab2rgb(image)*255).astype('uint8') + image = seq(images=[image])[0].squeeze() + + dataset_dict['annotations']=mask2polygons(maskData) + utils.check_image_size(dataset_dict, image) + + sem_seg_gt = maskData + sem_seg_gt[sem_seg_gt>2]=0 + sem_seg_gt[maskData==0] = 3 + sem_seg_gt=np.array(sem_seg_gt).astype('uint8') + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + transforms = self.augmentations(aug_input) + image, sem_seg_gt = aug_input.image, aug_input.sem_seg + + image_shape = image.shape[:2] # h, w + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) + + if "annotations" in dataset_dict: + self._transform_annotations(dataset_dict, transforms, image_shape) + + + return dataset_dict \ No newline at end of file diff --git a/multic/segmentationschool/Codes/engine/hooks.py b/multic/segmentationschool/Codes/engine/hooks.py new file mode 100644 index 0000000..ec1eaf6 --- /dev/null +++ b/multic/segmentationschool/Codes/engine/hooks.py @@ -0,0 +1,69 @@ +import datetime +import logging +import time +import torch +import numpy as np +import detectron2.utils.comm as comm +from detectron2.utils.logger import log_every_n_seconds +from detectron2.engine.hooks import HookBase + +class LossEvalHook(HookBase): + def __init__(self, eval_period, model, data_loader): + self._model = model + self._period = eval_period + self._data_loader = data_loader + + def _do_loss_eval(self): + # Copying inference_on_dataset from evaluator.py + total = len(self._data_loader) + num_warmup = min(5, total - 1) + + start_time = time.perf_counter() + total_compute_time = 0 + losses = [] + for idx, inputs in enumerate(self._data_loader): + if idx == num_warmup: + start_time = time.perf_counter() + total_compute_time = 0 + start_compute_time = time.perf_counter() + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_compute_time += time.perf_counter() - start_compute_time + iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) + seconds_per_img = total_compute_time / iters_after_start + if idx >= num_warmup * 2 or seconds_per_img > 5: + total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start + eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) + log_every_n_seconds( + logging.INFO, + "Loss on Validation done {}/{}. {:.4f} s / img. 
ETA={}".format( + idx + 1, total, seconds_per_img, str(eta) + ), + n=5, + ) + loss_batch = self._get_loss(inputs) + losses.append(loss_batch) + mean_loss = np.mean(losses) + self.trainer.storage.put_scalar('validation_loss', mean_loss) + comm.synchronize() + + return losses + + def _get_loss(self, data): + # How loss is calculated on train_loop + metrics_dict = self._model(data) + metrics_dict = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v) + for k, v in metrics_dict.items() + } + total_losses_reduced = sum(loss for loss in metrics_dict.values()) + return total_losses_reduced + + + def after_step(self): + next_iter = self.trainer.iter + 1 + is_final = next_iter == self.trainer.max_iter + if is_final or (self._period > 0 and next_iter % self._period == 0): + self._do_loss_eval() + self.trainer.storage.put_scalars(timetest=12) + \ No newline at end of file diff --git a/multic/segmentationschool/Codes/wsi_loader_utils.py b/multic/segmentationschool/Codes/wsi_loader_utils.py index ea87ae6..90a98f8 100644 --- a/multic/segmentationschool/Codes/wsi_loader_utils.py +++ b/multic/segmentationschool/Codes/wsi_loader_utils.py @@ -1,85 +1,103 @@ -import openslide,glob,os +import glob, os import numpy as np from scipy.ndimage.morphology import binary_fill_holes -import matplotlib.pyplot as plt from skimage.color import rgb2hsv from skimage.filters import gaussian -# from skimage.morphology import binary_dilation, diamond -# import cv2 from tqdm import tqdm from skimage.io import imread,imsave import multiprocessing from joblib import Parallel, delayed +from shapely.geometry import Polygon +from tiffslide import TiffSlide +import random +import glob +import warnings +from joblib import Parallel, delayed +import multiprocessing +from .xml_to_mask_minmax import write_minmax_to_xml +import lxml.etree as ET -def save_thumb(args,slide_loc): - print(slide_loc) - slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - slide=openslide.OpenSlide(slide_loc) - if slideExt =='.scn': - dim_x=int(slide.properties['openslide.bounds-width'])## add to columns - dim_y=int(slide.properties['openslide.bounds-height'])## add to rows - offsetx=int(slide.properties['openslide.bounds-x'])##start column - offsety=int(slide.properties['openslide.bounds-y'])##start row - elif slideExt in ['.ndpi','.svs']: - dim_x, dim_y=slide.dimensions - offsetx=0 - offsety=0 +MPP = {'V42D20-364_XY01_2235505.svs':0.25, + 'V42D20-364_XY04_2240610.svs':0.25, + 'V42N07-339_XY04_F44.svs':0.25, + 'V42N07-395_XY01_235142.svs':0.25, + 'V42N07-395_XY04_235582.svs':0.25, + 'V42N07-399_XY01_3723.svs':0.25, + 'XY01_IU-21-015F.svs':0.50, + 'XY02_IU-21-016F.svs':0.50, + 'XY03_IU-21-019F.svs':0.50, + 'XY04_IU-21-020F.svs':0.50} - # fullSize=slide.level_dimensions[0] - # resRatio= args.chop_thumbnail_resolution - # ds_1=fullSize[0]/resRatio - # ds_2=fullSize[1]/resRatio - # thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2))) - # if slideExt =='.scn': - # xStt=int(offsetx/resRatio) - # xStp=int((offsetx+dim_x)/resRatio) - # yStt=int(offsety/resRatio) - # yStp=int((offsety+dim_y)/resRatio) - # thumbIm=thumbIm[yStt:yStp,xStt:xStp] - # imsave(slide_loc.replace(slideExt,'_thumb.jpeg'),thumbIm) - slide.associated_images['label'].save(slide_loc.replace(slideExt,'_label.png')) - # imsave(slide_loc.replace(slideExt,'_label.png'),slide.associated_images['label']) - - -def get_image_thumbnails(args): - assert args.target is not None, 'Location of images must be provided' - all_slides=[] - for ext in args.wsi_ext.split(','): - 
all_slides.extend(glob.glob(args.target+'/*'+ext)) - Parallel(n_jobs=multiprocessing.cpu_count())(delayed(save_thumb)(args,slide_loc) for slide_loc in tqdm(all_slides)) - # for slide_loc in tqdm(all_slides): - -class WSIPredictLoader(): - def __init__(self,args, wsi_directory=None, transform=None): - assert wsi_directory is not None, 'location of training svs and xml must be provided' - mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),) - if not os.path.exists(mask_out_loc): - os.makedirs(mask_out_loc) - all_slides=[] - for ext in args.wsi_ext.split(','): - all_slides.extend(glob.glob(wsi_directory+'/*'+ext)) - print('Getting slide metadata and usable regions...') - usable_slides=[] - for slide_loc in all_slides: - slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - print("working slide... "+ slideID,end='\r') - - slide=openslide.OpenSlide(slide_loc) - chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) - mag_x=np.round(float(slide.properties['openslide.mpp-x']),2) - mag_y=np.round(float(slide.properties['openslide.mpp-y']),2) - print(mag_x,mag_y) - usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide, - 'chop_array':chop_array,'mag':[mag_x,mag_y]}) - self.usable_slides= usable_slides - self.boxSize40X = args.boxSize - self.boxSize20X = int(args.boxSize)/2 - -class WSITrainingLoader(): - def __init__(self,args, wsi_directory=None, transform=None): +def get_image_meta(i,args): + image_annotation_info={} + image_annotation_info['slide_loc']=i[0] + slide=TiffSlide(image_annotation_info['slide_loc']) + magx=MPP[image_annotation_info['slide_loc'].split('/')[-1]]#np.round(float(slide.properties['tiffslide.mpp-x']),2) + magy=MPP[image_annotation_info['slide_loc'].split('/')[-1]]#np.round(float(slide.properties['tiffslide.mpp-y']),2) + + assert magx == magy + if magx ==0.25: + dx=args.boxSize + dy=args.boxSize + elif magx == 0.5: + dx=int(args.boxSize/2) + dy=int(args.boxSize/2) + else: + print('nonstandard image magnification') + print(slide) + print(magx,magy) + exit() + + image_annotation_info['coordinates']=[i[2][1],i[2][0]] + image_annotation_info['height']=dx + image_annotation_info['width']=dy + image_annotation_info['image_id']=i[1].split('/')[-1].replace('.xml','_'.join(['',str(i[2][1]),str(i[2][0])])) + image_annotation_info['xml_loc']=i[1] + image_annotation_info['file_name']=i[1].split('/')[-1] + slide.close() + + return image_annotation_info + +def train_samples_from_WSI(args,image_coordinates): + + + num_cores=multiprocessing.cpu_count() + + print('Generating detectron2 dictionary format...') + data_list=Parallel(n_jobs=num_cores//2,backend='threading')(delayed(get_image_meta)(i=i, + args=args) for i in tqdm(image_coordinates)) + # print(len(data_list),'this is') + # data_list=[] + # for i in tqdm(image_coordinates): + # data_list.append(get_image_meta(i=i,args=args)) + return data_list + +def WSIGridIterator(wsi_name,choppable_regions,index_x,index_y,region_size,dim_x,dim_y): + wsi_name=os.path.splitext(wsi_name.split('/')[-1])[0] + data_list=[] + for idxy, i in tqdm(enumerate(index_y)): + for idxx, j in enumerate(index_x): + if choppable_regions[idxy, idxx] != 0: + yEnd = min(dim_y,i+region_size) + xEnd = min(dim_x,j+region_size) + xLen=xEnd-j + yLen=yEnd-i + + image_annotation_info={} + image_annotation_info['file_name']='_'.join([wsi_name,str(j),str(i),str(xEnd),str(yEnd)]) + image_annotation_info['height']=yLen + image_annotation_info['width']=xLen + 
image_annotation_info['image_id']=image_annotation_info['file_name'] + image_annotation_info['xStart']=j + image_annotation_info['yStart']=i + data_list.append(image_annotation_info) + return data_list + +def get_slide_data(args, wsi_directory=None): assert wsi_directory is not None, 'location of training svs and xml must be provided' - mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),) + + mask_out_loc=os.path.join(wsi_directory, 'Tissue_masks') if not os.path.exists(mask_out_loc): os.makedirs(mask_out_loc) all_slides=[] @@ -90,34 +108,104 @@ def __init__(self,args, wsi_directory=None, transform=None): usable_slides=[] for slide_loc in all_slides: slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - print("working slide... "+ slideID,end='\r') - - slide=openslide.OpenSlide(slide_loc) - chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) - mag_x=np.round(float(slide.properties['openslide.mpp-x']),2) - mag_y=np.round(float(slide.properties['openslide.mpp-y']),2) - print(mag_x,mag_y) - usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide, - 'chop_array':chop_array,'mag':[mag_x,mag_y]}) - self.usable_slides= usable_slides - self.boxSize40X = args.boxSize - self.boxSize20X = int(args.boxSize)/2 + xmlpath=slide_loc.replace(slideExt,'.xml') + if os.path.isfile(xmlpath): + write_minmax_to_xml(xmlpath) + + print("Gathering slide data ... "+ slideID,end='\r') + slide =TiffSlide(slide_loc) + chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) + + mag_x=MPP[slideID+slideExt]#np.round(float(slide.properties['tiffslide.mpp-x']),2) + mag_y=MPP[slideID+slideExt]#np.round(float(slide.properties['tiffslide.mpp-y']),2) + slide.close() + tree = ET.parse(xmlpath) + root = tree.getroot() + balance_classes=args.balanceClasses.split(',') + classNums={} + for b in balance_classes: + classNums[b]=0 + # balance_annotations={} + for Annotation in root.findall("./Annotation"): + + annotationID = Annotation.attrib['Id'] + if annotationID=='7': + print(xmlpath) + exit() + if annotationID in classNums.keys(): + + classNums[annotationID]=len(Annotation.findall("./*/Region")) + else: + pass + + usable_slides.append({'slide_loc':slide_loc,'slideID':slideID, + 'chop_array':chop_array,'num_regions':len(chop_array),'mag':[mag_x,mag_y], + 'xml_loc':xmlpath,'annotations':classNums,'root':root + }) + else: + print('\n') + print('no annotation XML file found for:') + print(slideID) + exit() print('\n') + return usable_slides + +def get_random_chops(slide_idx,usable_slides,region_size): + # chops=[] + choplen=len(slide_idx) + chops=Parallel(n_jobs=multiprocessing.cpu_count(),backend='threading')(delayed(get_chop_data)(idx=idx, + usable_slides=usable_slides,region_size=region_size) for idx in tqdm(slide_idx)) + return chops + + +def get_chop_data(idx,usable_slides,region_size): + if random.random()>0.5: + randSelect=random.randrange(0,usable_slides[idx]['num_regions']) + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + usable_slides[idx]['chop_array'][randSelect]] + else: + # print(list(usable_slides[idx]['annotations'].values())) + if sum(usable_slides[idx]['annotations'].values())==0: + randSelect=random.randrange(0,usable_slides[idx]['num_regions']) + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + usable_slides[idx]['chop_array'][randSelect]] + else: + classIDs=list(usable_slides[idx]['annotations'].keys()) + 
classSamples=random.sample(classIDs,len(classIDs)) + for c in classSamples: + if usable_slides[idx]['annotations'][c]==0 or c == '5': + continue + else: + sampledRegionID=random.choice([r.attrib['Id'] for r in usable_slides[idx]['root'].find("./Annotation[@Id='{}']/Regions".format(c)).findall('Region')]) + #sampledRegionID=random.randrange(1,usable_slides[idx]['annotations'][c]+1) + break + + + Verts = usable_slides[idx]['root'].findall("./Annotation[@Id='{}']/Regions/Region[@Id='{}']/Vertices/Vertex".format(c,sampledRegionID)) + + centroid = (Polygon([(int(float(k.attrib['X'])),int(float(k.attrib['Y']))) for k in Verts]).centroid) + + randVertX=int(centroid.x)-region_size//2 + randVertY=int(centroid.y)-region_size//2 + + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + [randVertY,randVertX]] + + + return chopData def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): slide_regions=[] choppable_regions_list=[] - downsample = int(args.downsampleRate**.5) #down sample for each dimension region_size = int(args.boxSize*(downsample)) #Region size before downsampling - step = int(region_size*(1-args.overlap_percent)) #Step size before downsampling - + step = int(region_size*(1-args.overlap_rate)) #Step size before downsampling if slideExt =='.scn': - dim_x=int(slide.properties['openslide.bounds-width'])## add to columns - dim_y=int(slide.properties['openslide.bounds-height'])## add to rows - offsetx=int(slide.properties['openslide.bounds-x'])##start column - offsety=int(slide.properties['openslide.bounds-y'])##start row + dim_x=int(slide.properties['tiffslide.bounds-width'])## add to columns + dim_y=int(slide.properties['tiffslide.bounds-height'])## add to rows + offsetx=int(slide.properties['tiffslide.bounds-x'])##start column + offsety=int(slide.properties['tiffslide.bounds-y'])##start row index_y=np.array(range(offsety,offsety+dim_y,step)) index_x=np.array(range(offsetx,offsetx+dim_x,step)) index_y[-1]=(offsety+dim_y)-step @@ -135,7 +223,9 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): resRatio= args.chop_thumbnail_resolution ds_1=fullSize[0]/resRatio ds_2=fullSize[1]/resRatio - if args.get_new_tissue_masks: + out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png'])) + if not os.path.isfile(out_mask_name) or args.get_new_tissue_masks: + print(out_mask_name) thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2))) if slideExt =='.scn': xStt=int(offsetx/resRatio) @@ -143,40 +233,18 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): yStt=int(offsety/resRatio) yStp=int((offsety+dim_y)/resRatio) thumbIm=thumbIm[yStt:yStp,xStt:xStp] - # plt.imshow(thumbIm) - # plt.show() - # input() - # plt.imshow(thumbIm) - # plt.show() - - out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png'])) - - - if not args.get_new_tissue_masks: - try: - binary=(imread(out_mask_name)/255).astype('bool') - except: - print('failed to load mask for '+ out_mask_name) - print('please set get_new_tissue masks to True') - exit() - # if slideExt =='.scn': - # choppable_regions=np.zeros((len(index_x),len(index_y))) - # elif slideExt in ['.ndpi','.svs']: choppable_regions=np.zeros((len(index_y),len(index_x))) - else: - print(out_mask_name) - # if slideExt =='.scn': - # choppable_regions=np.zeros((len(index_x),len(index_y))) - # elif slideExt in ['.ndpi','.svs']: - choppable_regions=np.zeros((len(index_y),len(index_x))) - hsv=rgb2hsv(thumbIm) g=gaussian(hsv[:,:,1],5) binary=(g>0.05).astype('bool') 
binary=binary_fill_holes(binary) imsave(out_mask_name.replace('.png','.jpeg'),thumbIm) - imsave(out_mask_name,binary.astype('uint8')*255) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + imsave(out_mask_name,binary.astype('uint8')*255) + binary=(imread(out_mask_name)/255).astype('bool') + choppable_regions=np.zeros((len(index_y),len(index_x))) chop_list=[] for idxy,yi in enumerate(index_y): for idxx,xj in enumerate(index_x): @@ -185,31 +253,13 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): xStart = int(np.round((xj-offsetx)/resRatio)) xStop = int(np.round(((xj-offsetx)+args.boxSize)/resRatio)) box_total=(xStop-xStart)*(yStop-yStart) - if slideExt =='.scn': - # print(xStart,xStop,yStart,yStop) - # print(np.sum(binary[xStart:xStop,yStart:yStop]),args.white_percent,box_total) - # plt.imshow(binary[xStart:xStop,yStart:yStop]) - # plt.show() - if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): - - choppable_regions[idxy,idxx]=1 - chop_list.append([index_y[idxy],index_x[idxx]]) - - elif slideExt in ['.ndpi','.svs']: - if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): - choppable_regions[idxy,idxx]=1 - chop_list.append([index_y[idxy],index_x[idxx]]) - - imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255) - - # plt.imshow(choppable_regions) - # plt.show() - # choppable_regions_list.extend(chop_list) - # plt.subplot(131) - # plt.imshow(thumbIm) - # plt.subplot(132) - # plt.imshow(binary) - # plt.subplot(133) - # plt.imshow(choppable_regions) - # plt.show() - return chop_list + + if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): + choppable_regions[idxy,idxx]=1 + chop_list.append([index_y[idxy],index_x[idxx]]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255) + + return chop_list \ No newline at end of file diff --git a/multic/segmentationschool/segmentation_school.py b/multic/segmentationschool/segmentation_school.py index 50e32c1..a0d5b42 100644 --- a/multic/segmentationschool/segmentation_school.py +++ b/multic/segmentationschool/segmentation_school.py @@ -174,7 +174,7 @@ def savetime(args, starttime): ##### Args for training / prediction #################################################### parser.add_argument('--gpu_num', dest='gpu_num', default=2 ,type=int, help='number of GPUs avalable') - parser.add_argument('--gpu', dest='gpu', default=0 ,type=int, + parser.add_argument('--gpu', dest='gpu', default="" ,type=str, help='GPU to use for prediction') parser.add_argument('--iteration', dest='iteration', default='none' ,type=str, help='Which iteration to use for prediction') @@ -188,6 +188,18 @@ def savetime(args, starttime): help='number of classes present in the High res training data [USE ONLY IF DIFFERENT FROM LOW RES]') parser.add_argument('--modelfile', dest='modelfile', default=None ,type=str, help='the desired model file to use for training or prediction') + parser.add_argument('--init_modelfile', dest='init_modelfile', default=None ,type=str, + help='the desired model file to use for training or prediction') + parser.add_argument('--eval_period', dest='eval_period', default=1000 ,type=int, + help='Validation Period') + parser.add_argument('--batch_size', dest='batch_size', default=4 ,type=int, + help='Size of batches for training high resolution CNN') + parser.add_argument('--train_steps', dest='train_steps', 
default=1000 ,type=int,
+        help='number of training steps (iterations) for the high resolution CNN')
+    parser.add_argument('--training_data_dir', dest='training_data_dir', default=os.getcwd(),type=str,
+        help='Training Data Folder')
+    parser.add_argument('--overlap_rate', dest='overlap_rate', default=0.5 ,type=float,
+        help='overlap percentage of high resolution blocks [0-1]')
 
     ### Params for cutting wsi ###
     #White level cutoff
@@ -202,10 +214,14 @@ def savetime(args, starttime):
         help='size of low resolution blocks')
     parser.add_argument('--downsampleRateLR', dest='downsampleRateLR', default=16 ,type=int,
         help='reduce image resolution to 1/downsample rate')
+    parser.add_argument('--get_new_tissue_masks', dest='get_new_tissue_masks', default=False,type=str2bool,
+        help="Don't load usable tissue regions from disk, create new ones")
+    parser.add_argument('--downsampleRate', dest='downsampleRate', default=1 ,type=int,
+        help='reduce image resolution to 1/downsample rate')
     #High resolution parameters
     parser.add_argument('--overlap_percentHR', dest='overlap_percentHR', default=0 ,type=float,
         help='overlap percentage of high resolution blocks [0-1]')
-    parser.add_argument('--boxSize', dest='boxSize', default=2048 ,type=int,
+    parser.add_argument('--boxSize', dest='boxSize', default=1200 ,type=int,
         help='size of high resolution blocks')
     parser.add_argument('--downsampleRateHR', dest='downsampleRateHR', default=1 ,type=int,
         help='reduce image resolution to 1/downsample rate')
@@ -226,7 +242,8 @@ def savetime(args, starttime):
         help='Gaussian variance defining bounds on Hue shift for HSV color augmentation')
     parser.add_argument('--lbound', dest='lbound', default=0.025 ,type=float,
         help='Gaussian variance defining bounds on L* gamma shift for color augmentation [alters brightness/darkness of image]')
-
+    parser.add_argument('--balanceClasses', dest='balanceClasses', default='3,4,5,6',type=str,
+        help="which classes to balance during training")
     ### Params for training networks ###
     #Low resolution hyperparameters
     parser.add_argument('--CNNbatch_sizeLR', dest='CNNbatch_sizeLR', default=2 ,type=int,
@@ -280,6 +297,8 @@ def savetime(args, starttime):
         help='padded region for low resolution region extraction')
     parser.add_argument('--show_interstitium', dest='show_interstitium', default=True ,type=str2bool,
         help='padded region for low resolution region extraction')
+    parser.add_argument('--num_workers', dest='num_workers', default=1 ,type=int,
+        help='Number of workers for data loader')



diff --git a/multic/segmentationschool/slurm_training.sh b/multic/segmentationschool/slurm_training.sh
new file mode 100644
index 0000000..e697617
--- /dev/null
+++ b/multic/segmentationschool/slurm_training.sh
@@ -0,0 +1,35 @@
+#!/bin/sh
+#SBATCH --account=pinaki.sarder
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --cpus-per-task=10
+#SBATCH --mem-per-cpu=16gb
+#SBATCH --partition=gpu
+#SBATCH --gpus=a100:2
+#SBATCH --time=72:00:00
+#SBATCH --output=./slurm_log.out
+#SBATCH --job-name="multic_training"
+echo "SLURM_JOBID="$SLURM_JOBID
+echo "SLURM_JOB_NODELIST="$SLURM_JOB_NODELIST
+echo "SLURM_NNODES="$SLURM_NNODES
+echo "SLURMTMPDIR="$SLURMTMPDIR
+
+echo "working directory = "$SLURM_SUBMIT_DIR
+ulimit -s unlimited
+module load singularity
+ls
+ml
+
+# Add your userid here:
+USER=sayat.mimar
+# Add the name of the folder containing WSIs here
+PROJECT=multic_segment
+
+CODESDIR=/blue/pinaki.sarder/sayat.mimar/Multi-Compartment-Segmentation/multic/segmentationschool
+
+DATADIR=$CODESDIR/TRAINING_data
+MODELDIR=$CODESDIR/pretrained_model
+
+CONTAINER=$CODESDIR/multic_segment.sif
+export CUDA_LAUNCH_BLOCKING=1
+singularity exec --nv -B $(pwd):/exec/,$DATADIR/:/data,$MODELDIR/:/model/ $CONTAINER python3 /exec/segmentation_school.py --option train --base_dir $CODESDIR --init_modelfile $MODELDIR/model_final.pth --training_data_dir $CODESDIR/TRAINING_data/first --train_steps 100000 --eval_period 25000 --num_workers 10
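
Note on the color augmentation introduced in CustomDatasetMapper.__call__ above: before the imgaug sequence runs, each training chop receives a random hue shift in HSV space followed by a random gamma adjustment of the L* channel in CIELAB space. The standalone Python sketch below isolates that step for reference; the helper name hsv_lab_jitter is ours, and the shift variances (0.05 for hue, 0.025 for gamma) mirror the values hard-coded in the mapper rather than the --hbound/--lbound defaults.

import numpy as np
from skimage import exposure
from skimage.color import rgb2hsv, hsv2rgb, rgb2lab, lab2rgb

def hsv_lab_jitter(image, h_sigma=0.05, l_sigma=0.025):
    """Randomly shift hue (HSV) and gamma-adjust L* (CIELAB) of an RGB uint8 image."""
    h_shift = np.random.normal(0, h_sigma)   # additive hue offset
    l_shift = np.random.normal(1, l_sigma)   # gamma centered at 1
    hsv = rgb2hsv(image)
    hsv[:, :, 0] = hsv[:, :, 0] + h_shift
    rgb = hsv2rgb(hsv)
    lab = rgb2lab(rgb)
    lab[:, :, 0] = exposure.adjust_gamma(lab[:, :, 0], l_shift)  # brighten or darken
    return (lab2rgb(lab) * 255).astype('uint8')

# Hypothetical usage on a 1200x1200 training chop read with TiffSlide:
# chop = np.array(slide.read_region((x, y), 0, (1200, 1200)))[:, :, :3]
# chop = hsv_lab_jitter(chop)

As in the mapper, the jitter is applied to roughly half of the chops (gated by random.random() > 0.5) before the imgaug noise/blur/sharpen sequence and the detectron2 T.AugmentationList transforms.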