diff --git a/Dockerfile b/Dockerfile index 4ca5afb..505d751 100644 --- a/Dockerfile +++ b/Dockerfile @@ -134,6 +134,7 @@ WORKDIR $mc_path/multic/cli RUN python -m slicer_cli_web.cli_list_entrypoint --list_cli RUN python -m slicer_cli_web.cli_list_entrypoint MultiCompartmentSegment --help RUN python -m slicer_cli_web.cli_list_entrypoint FeatureExtraction --help +RUN python -m slicer_cli_web.cli_list_entrypoint MultiCompartmentTrain --help ENTRYPOINT ["/bin/bash", "docker-entrypoint.sh"] diff --git a/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml b/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml index 9c13982..005f086 100644 --- a/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml +++ b/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml @@ -4,7 +4,7 @@ Multi Compartment Segmentation Segments multi-level structures from a whole-slide image 0.1.0 - https://github.com/SarderLab/deeplab-WSI + https://github.com/SarderLab/Multi-Compartment-Segmentation Apache 2.0 Sayat Mimar (UFL) This work is part of efforts in digital pathology by the Sarder Lab: UFL. diff --git a/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py new file mode 100644 index 0000000..0c31306 --- /dev/null +++ b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py @@ -0,0 +1,173 @@ +import os +import sys +from glob import glob +import girder_client +from ctk_cli import CLIArgumentParser + +sys.path.append("..") +from segmentationschool.utils.mask_to_xml import xml_create, xml_add_annotation, xml_add_region, xml_save +from segmentationschool.utils.xml_to_mask import write_minmax_to_xml + +NAMES = ['cortical_interstitium','medullary_interstitium','non_globally_sclerotic_glomeruli','globally_sclerotic_glomeruli','tubules','arteries/arterioles'] + + +def main(args): + + folder = args.training_data_dir + base_dir_id = folder.split('/')[-2] + _ = os.system("printf '\nUsing data from girder_client Folder: {}\n'".format(folder)) + + _ = os.system("printf '\n---\n\nFOUND: [{}]\n'".format(args.init_modelfile)) + + gc = girder_client.GirderClient(apiUrl=args.girderApiUrl) + gc.setToken(args.girderToken) + # get files in folder + files = gc.listItem(base_dir_id) + xml_color=[65280]*(len(NAMES)+1) + cwd = os.getcwd() + print(cwd) + os.chdir(cwd) + + tmp = folder + + slides_used = [] + ignore_label = len(NAMES)+1 + for file in files: + slidename = file['name'] + _ = os.system("printf '\n---\n\nFOUND: [{}]\n'".format(slidename)) + skipSlide = 0 + + # get annotation + item = gc.getItem(file['_id']) + annot = gc.get('/annotation/item/{}'.format(item['_id']), parameters={'sort': 'updated'}) + annot.reverse() + annot = list(annot) + _ = os.system("printf '\tfound [{}] annotation layers...\n'".format(len(annot))) + + # create root for xml file + xmlAnnot = xml_create() + + # all compartments + for class_,compart in enumerate(NAMES): + + compart = compart.replace(' ','') + class_ +=1 + # add layer to xml + xmlAnnot = xml_add_annotation(Annotations=xmlAnnot, xml_color=xml_color, annotationID=class_) + + # test all annotation layers in order created + for iter,a in enumerate(annot): + + + try: + # check for annotation layer by name + a_name = a['annotation']['name'].replace(' ','') + except: + a_name = None + + if a_name == compart: + # track all layers present + skipSlide +=1 + + pointsList = [] + + # load json data + _ = os.system("printf '\tloading annotation layer: [{}]\n'".format(compart)) + + a_data = a['annotation']['elements'] + + for data in a_data: + pointList = [] + points = data['points'] + for point in points: + pt_dict = {'X': round(point[0]), 'Y': round(point[1])} + pointList.append(pt_dict) + pointsList.append(pointList) + + # write annotations to xml + for i in range(len(pointsList)): + pointList = pointsList[i] + xmlAnnot = xml_add_region(Annotations=xmlAnnot, pointList=pointList, annotationID=class_) + + # print(a['_version'], a['updated'], a['created']) + break + + if skipSlide != len(NAMES): + _ = os.system("printf '\tThis slide is missing annotation layers\n'") + _ = os.system("printf '\tSKIPPING SLIDE...\n'") + del xmlAnnot + continue # correct layers not present + # compart = 'ignore_label' + # # test all annotation layers in order created + # for iter,a in enumerate(annot): + # try: + # # check for annotation layer by name + # a_name = a['annotation']['name'].replace(' ','') + # except: + # a_name = None + # if a_name == compart: + # pointsList = [] + # # load json data + # _ = os.system("printf '\tloading annotation layer: [{}]\n'".format(compart)) + # a_data = a['annotation']['elements'] + # for data in a_data: + # pointList = [] + # if data['type'] == 'polyline': + # points = data['points'] + # elif data['type'] == 'rectangle': + # center = data['center'] + # width = data['width']/2 + # height = data['height']/2 + # points = [[ center[0]-width, center[1]-width ],[ center[0]+width, center[1]+width ]] + # for point in points: + # pt_dict = {'X': round(point[0]), 'Y': round(point[1])} + # pointList.append(pt_dict) + # pointsList.append(pointList) + # # write annotations to xml + + # for i in range(len(pointsList)): + # pointList = pointsList[i] + # xmlAnnot = xml_add_region(Annotations=xmlAnnot, pointList=pointList, annotationID=ignore_label) + # break + + # include slide and fetch annotations + _ = os.system("printf '\tFETCHING SLIDE...\n'") + os.rename('{}/{}'.format(folder, slidename), '{}/{}'.format(tmp, slidename)) + slides_used.append(slidename) + + xml_path = '{}/{}.xml'.format(tmp, os.path.splitext(slidename)[0]) + _ = os.system("printf '\tsaving a created xml annotation file: [{}]\n'".format(xml_path)) + xml_save(Annotations=xmlAnnot, filename=xml_path) + write_minmax_to_xml(xml_path) # to avoid trying to write to the xml from multiple workers + del xmlAnnot + os.system("ls -lh '{}'".format(tmp)) + + trainlogdir=os.path.join(tmp, 'output') + if not os.path.exists(trainlogdir): + os.makedirs(trainlogdir) + + _ = os.system("printf '\ndone retriving data...\nstarting training...\n\n'") + + + cmd = "python3 ../segmentationschool/segmentation_school.py --option {} --training_data_dir {} --init_modelfile {} --gpu {} --train_steps {} --num_workers {} --girderApiUrl {} --girderToken {}".format('train', tmp.replace(' ', '\ '), args.init_modelfile, args.gpu, args.training_steps, args.num_workers, args.girderApiUrl, args.girderToken) + print(cmd) + sys.stdout.flush() + os.system(cmd) + + os.listdir(trainlogdir) + os.chdir(trainlogdir) + os.system('pwd') + os.system('ls -lh') + + filelist = glob('*.pth') + latest_model = max(filelist, key=os.path.getmtime) + + _ = os.system("printf '\n{}\n'".format(latest_model)) + os.rename(latest_model, args.output_model) + + _ = os.system("printf '\nDone!\n\n'") + + + +if __name__ == "__main__": + main(CLIArgumentParser().parse_args()) diff --git a/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml new file mode 100644 index 0000000..9f70a39 --- /dev/null +++ b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml @@ -0,0 +1,75 @@ + + + HistomicsTK + Multi Compartment Training + Trains Multi compartment segmentation model + 0.1.0 + https://github.com/SarderLab/Multi-Compartment-Segmentation + Apache 2.0 + Sayat Mimar (UFL) + This work is part of efforts in digital pathology by the Sarder Lab: UFL. + + + Input/output parameters + + training_data_dir + + Base Directory for the model + input + 0 + + + init_modelfile + + Trained model file + input + 1 + + + gpu + + A comma separated list of the GPU IDs that will be made avalable for training + 0,1 + 2 + + + training_steps + + The number of steps used for network training. The network will see [steps * batch size] image patches during training + 5000 + 3 + + + num_workers + + Number of workers for Dataloader + 0 + 4 + + + output_model + + Select the name of the output model file produced. By default this will be saved in your Private folder. + output + 5 + + + + + A Girder API URL and token for Girder client + + girderApiUrl + api-url + + A Girder API URL (e.g., https://girder.example.com:443/api/v1) + + + + girderToken + token + + A Girder token + + + + diff --git a/multic/cli/slicer_cli_list.json b/multic/cli/slicer_cli_list.json index 1189b00..b6a048f 100644 --- a/multic/cli/slicer_cli_list.json +++ b/multic/cli/slicer_cli_list.json @@ -4,5 +4,8 @@ }, "FeatureExtraction": { "type" : "python" + }, + "MultiCompartmentTrain": { + "type" : "python" } } diff --git a/multic/segmentationschool/Codes/IterativeTraining_1X.py b/multic/segmentationschool/Codes/IterativeTraining_1X.py index 9c98890..0b92135 100644 --- a/multic/segmentationschool/Codes/IterativeTraining_1X.py +++ b/multic/segmentationschool/Codes/IterativeTraining_1X.py @@ -1,677 +1,343 @@ -import os, sys, cv2, time, random, warnings, multiprocessing#json,# detectron2 +import os,cv2, time, random, multiprocessing,copy +from skimage.color import rgb2hsv,hsv2rgb,rgb2lab,lab2rgb import numpy as np -import matplotlib.pyplot as plt -import lxml.etree as ET -from matplotlib import path -from skimage.transform import resize -from skimage.io import imread, imsave -import glob -from .getWsi import getWsi - -from .xml_to_mask2 import get_supervision_boxes, regions_in_mask_dots, get_vertex_points_dots, masks_from_points, restart_line -from joblib import Parallel, delayed -from shutil import move +from tiffslide import TiffSlide +from .xml_to_mask_minmax import xml_to_mask # from generateTrainSet import generateDatalists -#from subprocess import call -#from .get_choppable_regions import get_choppable_regions -from PIL import Image - +import logging from detectron2.utils.logger import setup_logger +from skimage import exposure + setup_logger() from detectron2 import model_zoo -from detectron2.engine import DefaultPredictor,DefaultTrainer +from detectron2.engine import DefaultTrainer from detectron2.config import get_cfg -from detectron2.utils.visualizer import Visualizer#,ColorMode -from detectron2.data import MetadataCatalog, DatasetCatalog -#from detectron2.structures import BoxMode -from .get_dataset_list import HAIL2Detectron, samples_from_json, samples_from_json_mini -#from detectron2.checkpoint import DetectionCheckpointer -#from detectron2.modeling import build_model - -""" - -Code for - cutting / augmenting / training CNN - -This uses WSI and XML files to train 2 neural networks for semantic segmentation - of histopath tissue via human in the loop training - -""" - +# from detectron2.data import MetadataCatalog, DatasetCatalog +from detectron2.data import detection_utils as utils +import detectron2.data.transforms as T +from detectron2.structures import BoxMode +from detectron2.data import (DatasetCatalog, + MetadataCatalog, + build_detection_test_loader, + build_detection_train_loader, +) +from detectron2.config import configurable +from typing import List, Optional, Union +import torch + +# sys.append("..") +from .wsi_loader_utils import train_samples_from_WSI, get_slide_data, get_random_chops +from imgaug import augmenters as iaa + + +global seq +seq = iaa.Sequential([ + iaa.Sometimes(0.5,iaa.OneOf([ + iaa.AddElementwise((-15,15),per_channel=0.5), + iaa.ImpulseNoise(0.05),iaa.CoarseDropout(0.02, size_percent=0.5)])), + iaa.Sometimes(0.5,iaa.OneOf([iaa.GaussianBlur(sigma=(0, 3.0)), + iaa.Sharpen(alpha=(0.0, 1.0), lightness=(0.75, 2.0))])) +]) #Record start time totalStart=time.time() def IterateTraining(args): - ## calculate low resolution block params - downsampleLR = int(args.downsampleRateLR**.5) #down sample for each dimension - region_sizeLR = int(args.boxSizeLR*(downsampleLR)) #Region size before downsampling - stepLR = int(region_sizeLR*(1-args.overlap_percentLR)) #Step size before downsampling - ## calculate low resolution block params - downsampleHR = int(args.downsampleRateHR**.5) #down sample for each dimension - region_sizeHR = int(args.boxSizeHR*(downsampleHR)) #Region size before downsampling - stepHR = int(region_sizeHR*(1-args.overlap_percentHR)) #Step size before downsampling - - global classNum_HR,classEnumLR,classEnumHR + + region_size = int(args.boxSize) #Region size before downsampling + dirs = {'imExt': '.jpeg'} dirs['basedir'] = args.base_dir dirs['maskExt'] = '.png' - dirs['modeldir'] = '/MODELS/' - dirs['tempdirLR'] = '/TempLR/' - dirs['tempdirHR'] = '/TempHR/' - dirs['pretraindir'] = '/Deeplab_network/' - dirs['training_data_dir'] = '/TRAINING_data/' - dirs['model_init'] = 'deeplab_resnet.ckpt' - dirs['project']= '/' + args.project - dirs['data_dir_HR'] = args.base_dir +'/' + args.project + '/Permanent/HR/' - dirs['data_dir_LR'] = args.base_dir +'/' +args.project + '/Permanent/LR/' + dirs['training_data_dir'] = args.training_data_dir - ##All folders created, initiate WSI loading by human - #raw_input('Please place WSIs in ') - ##Check iteration session - currentmodels=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir']) - print('Handcoded iteration') - # currentAnnotationIteration=check_model_generation(dirs) - currentAnnotationIteration=2 - print('Current training session is: ' + str(currentAnnotationIteration)) - - ##Create objects for storing class distributions - annotatedXMLs=glob.glob(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration) + '/*.xml') - classes=[] + print('Handcoded iteration') - if args.classNum == 0: - for xml in annotatedXMLs: - classes.append(get_num_classes(xml)) - classNum_HR = max(classes) + os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu + os.environ["CUDA_LAUNCH_BLOCKING"] ='1' + + organType='kidney' + print('Organ meta being set to... '+ organType) + if organType=='liver': + classnames=['Background','BD','A'] + isthing=[0,1,1] + xml_color = [[0,255,0], [0,255,255], [0,0,255]] + tc=['BD','AT'] + sc=['Ob','B'] + elif organType =='kidney': + classnames=['interstitium','medulla','glomerulus','sclerotic glomerulus','tubule','arterial tree'] + classes={} + isthing=[0,0,1,1,1,1] + xml_color = [[0,255,0], [0,255,255], [255,255,0],[0,0,255], [255,0,0], [0,128,255]] + tc=['G','SG','T','A'] + sc=['Ob','I','M','B'] else: - classNum_LR = args.classNum - if args.classNum_HR != 0: - classNum_HR = args.classNum_HR - else: - classNum_HR = classNum_LR - - classNum_HR=args.classNum - - ##for all WSIs in the initiating directory: - if args.chop_data == 'True': - print('Chopping') - - start=time.time() - size_data=[] - - for xmlID in annotatedXMLs: - - #Get unique name of WSI - fileID=xmlID.split('/')[-1].split('.xml')[0] - print('-----------------'+fileID+'----------------') - #create memory addresses for wsi files - for ext in [args.wsi_ext]: - wsiID=dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration) +'/'+ fileID + ext - - #Ensure annotations exist - if os.path.isfile(wsiID)==True: - break - - - #Load openslide information about WSI - if ext != '.tif': - slide=getWsi(wsiID) - #WSI level 0 dimensions (largest size) - dim_x,dim_y=slide.dimensions - else: - im = Image.open(wsiID) - dim_x, dim_y=im.size - location=[0,0] - size=[dim_x,dim_y] - tree = ET.parse(xmlID) - root = tree.getroot() - box_supervision_layers=['8'] - # calculate region bounds - global_bounds = {'x_min' : location[0], 'y_min' : location[1], 'x_max' : location[0] + size[0], 'y_max' : location[1] + size[1]} - local_bounds = get_supervision_boxes(root,box_supervision_layers) - num_cores = multiprocessing.cpu_count() - Parallel(n_jobs=num_cores)(delayed(chop_suey_bounds)(args=args,wsiID=wsiID, - dirs=dirs,lb=lb,xmlID=xmlID,box_supervision_layers=box_supervision_layers) for lb in tqdm(local_bounds)) - # for lb in tqdm(local_bounds): - - # size_data.extend(image_sizes) - - ''' - wsi_mask=xml_to_mask(xmlID, [0,0], [dim_x,dim_y]) + print('Provided organType not in supported types: kidney, liver') - #Enumerate cpu core count - num_cores = multiprocessing.cpu_count() + classNum=len(tc)+len(sc)-1 + print('Number classes: '+ str(classNum)) + classes={} - #Generate iterators for parallel chopping of WSIs in high resolution - #index_yHR=range(30240,dim_y-stepHR,stepHR) - #index_xHR=range(840,dim_x-stepHR,stepHR) - index_yHR=range(0,dim_y,stepHR) - index_xHR=range(0,dim_x,stepHR) - index_yHR[-1]=dim_y-stepHR - index_xHR[-1]=dim_x-stepHR - #Create memory address for chopped images high resolution - outdirHR=dirs['basedir'] + dirs['project'] + dirs['tempdirHR'] + for idx,c in enumerate(classnames): + classes[idx]={'isthing':isthing[idx],'color':xml_color[idx]} - #Perform high resolution chopping in parallel and return the number of - #images in each of the labeled classes - chop_regions=get_choppable_regions(wsi=wsiID, - index_x=index_xHR,index_y=index_yHR,boxSize=region_sizeHR,white_percent=args.white_percent) - Parallel(n_jobs=num_cores)(delayed(return_region)(args=args, - wsi_mask=wsi_mask, wsiID=wsiID, - fileID=fileID, yStart=j, xStart=i, idxy=idxy, - idxx=idxx, downsampleRate=args.downsampleRateHR, - outdirT=outdirHR, region_size=region_sizeHR, - dirs=dirs, chop_regions=chop_regions,classNum_HR=classNum_HR) for idxx,i in enumerate(index_xHR) for idxy,j in enumerate(index_yHR)) - - #for idxx,i in enumerate(index_xHR): - # for idxy,j in enumerate(index_yHR): - # if chop_regions[idxy,idxx] != 0: - # return_region(args=args,xmlID=xmlID, wsiID=wsiID, fileID=fileID, yStart=j, xStart=i,idxy=idxy, idxx=idxx, - # downsampleRate=args.downsampleRateHR,outdirT=outdirHR, region_size=region_sizeHR, dirs=dirs, - # chop_regions=chop_regions,classNum_HR=classNum_HR) - # else: - # print('pass') - - - # exit() - print('Time for WSI chopping: ' + str(time.time()-start)) - - classEnumHR=np.ones([classNum_HR,1])*classNum_HR - - ##High resolution augmentation - #Enumerate high resolution class distribution - classDistHR=np.zeros(len(classEnumHR)) - for idx,value in enumerate(classEnumHR): - classDistHR[idx]=value/sum(classEnumHR) - print(classDistHR) - #Define number of augmentations per class - - moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions/', dirs['basedir']+dirs['project'] + '/Permanent/HR/regions/') - moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/masks/',dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/') - - - #Total time - print('Time for high resolution augmenting: ' + str((time.time()-totalStart)/60) + ' minutes.') - ''' - - # with open('sizes.csv','w',newline='') as myfile: - # wr=csv.writer(myfile,quoting=csv.QUOTE_ALL) - # wr.writerow(size_data) - # pretrain_HR=get_pretrain(currentAnnotationIteration,'/HR/',dirs) - - modeldir_HR = dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration+1) + '/HR/' - - - ##### HIGH REZ ARGS ##### - dirs['outDirAIHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/regions/' - dirs['outDirAMHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/masks/' - - - numImagesHR=len(glob.glob(dirs['outDirAIHR'] + '*' + dirs['imExt'])) - - numStepsHR=(args.epoch_HR*numImagesHR)/ args.CNNbatch_sizeHR - - - #----------------------------------------------------------------------------------------- - # os.environ["CUDA_VISIBLE_DEVICES"]='0' - os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) - # img_dir='/hdd/bg/Detectron2/chop_detectron/Permanent/HR' - - img_dir=dirs['outDirAIHR'] - classnames=['Background','BD','A'] - isthing=[0,1,1] - xml_color = [[0,255,0], [0,255,255], [0,0,255]] - - rand_sample=True - - json_file=img_dir+'/detectron_train.json' - HAIL2Detectron(img_dir,rand_sample,json_file,classnames,isthing,xml_color) - tc=['BD','AT'] - sc=['I','B'] - #### From json - DatasetCatalog.register("my_dataset", lambda:samples_from_json(json_file,rand_sample)) + num_images=args.batch_size*args.train_steps + # slide_idxs=train_dset.get_random_slide_idx(num_images) + usable_slides=get_slide_data(args, wsi_directory = dirs['training_data_dir']) + print('Number of slides:', len(usable_slides)) + usable_idx=range(0,len(usable_slides)) + slide_idxs=random.choices(usable_idx,k=num_images) + image_coordinates=get_random_chops(slide_idxs,usable_slides,region_size) + + DatasetCatalog.register("my_dataset", lambda:train_samples_from_WSI(args,image_coordinates)) MetadataCatalog.get("my_dataset").set(thing_classes=tc) MetadataCatalog.get("my_dataset").set(stuff_classes=sc) + - seg_metadata=MetadataCatalog.get("my_dataset") - - - # new_list = DatasetCatalog.get("my_dataset") - # print(len(new_list)) - # for d in random.sample(new_list, 100): - # - # img = cv2.imread(d["file_name"]) - # visualizer = Visualizer(img[:, :, ::-1],metadata=seg_metadata, scale=0.5) - # out = visualizer.draw_dataset_dict(d) - # cv2.namedWindow("output", cv2.WINDOW_NORMAL) - # cv2.imshow("output",out.get_image()[:, :, ::-1]) - # cv2.waitKey(0) # waits until a key is pressed - # cv2.destroyAllWindows() - # exit() + _ = os.system("printf '\nTraining starts...\n'") cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml")) cfg.DATASETS.TRAIN = ("my_dataset") - cfg.DATASETS.TEST = () - num_cores = multiprocessing.cpu_count() - cfg.DATALOADER.NUM_WORKERS = num_cores-3 - # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo - cfg.MODEL.WEIGHTS = os.path.join('/hdd/bg/Detectron2/HAIL_Detectron2/liver/MODELS/0/HR', "model_final.pth") + #cfg.TEST.EVAL_PERIOD=args.eval_period + #cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50 + #num_cores = multiprocessing.cpu_count() + cfg.DATALOADER.NUM_WORKERS = args.num_workers + + if args.init_modelfile: + cfg.MODEL.WEIGHTS = args.init_modelfile + else: + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo - cfg.SOLVER.IMS_PER_BATCH = 10 + cfg.SOLVER.IMS_PER_BATCH = args.batch_size - # cfg.SOLVER.BASE_LR = 0.02 # pick a good LR - # cfg.SOLVER.LR_policy='steps_with_lrs' - # cfg.SOLVER.MAX_ITER = 50000 - # cfg.SOLVER.STEPS = [30000,40000] - # # cfg.SOLVER.STEPS = [] - # cfg.SOLVER.LRS = [0.002,0.0002] - cfg.SOLVER.BASE_LR = 0.002 # pick a good LR cfg.SOLVER.LR_policy='steps_with_lrs' - cfg.SOLVER.MAX_ITER = 200000 - cfg.SOLVER.STEPS = [150000,180000] - # cfg.SOLVER.STEPS = [] - cfg.SOLVER.LRS = [0.0002,0.00002] - - # cfg.INPUT.CROP.ENABLED = True - # cfg.INPUT.CROP.TYPE='absolute' - # cfg.INPUT.CROP.SIZE=[100,100] - cfg.MODEL.BACKBONE.FREEZE_AT = 0 - # cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[4],[8],[16], [32], [64], [64], [64]] - # cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p2', 'p2', 'p3','p4','p5','p6'] - cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.33, 0.5, 1.0, 2.0, 3.0]] + cfg.SOLVER.MAX_ITER = args.train_steps + cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR + cfg.SOLVER.LRS = [0.000025,0.0000025] + cfg.SOLVER.STEPS = [70000,90000] + cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32],[64],[128], [256], [512], [1024]] + cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p3', 'p4', 'p5','p6','p6'] + cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[.1,.2,0.33, 0.5, 1.0, 2.0, 3.0,5,10]] cfg.MODEL.ANCHOR_GENERATOR.ANGLES=[-90,-60,-30,0,30,60,90] - - cfg.MODEL.RPN.POSITIVE_FRACTION = 0.75 - cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(tc) cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES =len(sc) - - - # cfg.INPUT.CROP.ENABLED = True - # cfg.INPUT.CROP.TYPE='absolute' - # cfg.INPUT.CROP.SIZE=[64,64] - - cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256 # faster, and good enough for this toy dataset (default: 512) - # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4 # only has one class (ballon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets) - + cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64 cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS=False - cfg.INPUT.MIN_SIZE_TRAIN=0 - # cfg.INPUT.MAX_SIZE_TRAIN=500 - - # exit() - os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) + cfg.INPUT.MIN_SIZE_TRAIN=args.boxSize + cfg.INPUT.MAX_SIZE_TRAIN=args.boxSize + cfg.DATASETS.TEST = () + cfg.OUTPUT_DIR = args.training_data_dir+"/output" + #os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) with open(cfg.OUTPUT_DIR+"/config_record.yaml", "w") as f: f.write(cfg.dump()) # save config to file - trainer = DefaultTrainer(cfg) + trainer = Trainer(cfg) + trainer.resume_or_load(resume=False) trainer.train() - cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained - cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.01 # set a custom testing threshold - cfg.TEST.DETECTIONS_PER_IMAGE = 500 - # - # cfg.INPUT.MIN_SIZE_TRAIN=64 - # cfg.INPUT.MAX_SIZE_TRAIN=4000 - cfg.INPUT.MIN_SIZE_TEST=64 - cfg.INPUT.MAX_SIZE_TEST=500 - - - predict_samples=100 - predictor = DefaultPredictor(cfg) - - dataset_dicts = samples_from_json_mini(json_file,predict_samples) - iter=0 - if not os.path.exists(os.getcwd()+'/network_predictions/'): - os.mkdir(os.getcwd()+'/network_predictions/') - for d in random.sample(dataset_dicts, predict_samples): - # print(d["file_name"]) - # imclass=d["file_name"].split('/')[-1].split('_')[-5].split(' ')[-1] - # if imclass in ["TRI","HE"]: - im = cv2.imread(d["file_name"]) - panoptic_seg, segments_info = predictor(im)["panoptic_seg"] # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format - # print(segments_info) - # plt.imshow(panoptic_seg.to("cpu")) - # plt.show() - v = Visualizer(im[:, :, ::-1], seg_metadata, scale=1.2) - v = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info) - # panoptic segmentation result - # plt.ion() - plt.subplot(121) - plt.imshow(im[:, :, ::-1]) - plt.subplot(122) - plt.imshow(v.get_image()) - plt.savefig(f"./network_predictions/input_{iter}.jpg",dpi=300) - plt.show() - # plt.ioff() - - - # v = Visualizer(im[:, :, ::-1], - # metadata=seg_metadata, - # scale=0.5, - # ) - # out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"),segments_info) - - # imsave('./network_predictions/pred'+str(iter)+'.png',np.hstack((im,v.get_image()))) - iter=iter+1 - # cv2.imshow('',out.get_image()[:, :, ::-1]) - # cv2.waitKey(0) # waits until a key is pressed - # cv2.destroyAllWindows() - #----------------------------------------------------------------------------------------- - - finish_model_generation(dirs,currentAnnotationIteration) - - print('\n\n\033[92;5mPlease place new wsi file(s) in: \n\t' + dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration+1)) - print('\nthen run [--option predict]\033[0m\n') - - - - -def moveimages(startfolder,endfolder): - filelist=glob.glob(startfolder + '*') - for file in filelist: - fileID=file.split('/')[-1] - move(file,endfolder + fileID) - - -def check_model_generation(dirs): - modelsCurrent=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir']) - gens=map(int,modelsCurrent) - modelOrder=np.sort(gens)[::-1] - - for idx in modelOrder: - #modelsChkptsLR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+str(modelsCurrent[idx]) + '/LR/*.ckpt*') - modelsChkptsHR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+ str(idx) +'/HR/*.ckpt*') - if modelsChkptsHR == []: - continue - else: - return idx - break - -def finish_model_generation(dirs,currentAnnotationIteration): - make_folder(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration + 1)) - -def get_pretrain(currentAnnotationIteration,res,dirs): - - if currentAnnotationIteration==0: - pretrain_file = glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + '*') - pretrain_file=pretrain_file[0].split('.')[0] + '.' + pretrain_file[0].split('.')[1] - - else: - pretrains=glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model*') - maxmodel=0 - for modelfiles in pretrains: - modelID=modelfiles.split('.')[-2].split('-')[1] - if int(modelID)>maxmodel: - maxmodel=int(modelID) - pretrain_file=dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model.ckpt-' + str(maxmodel) - return pretrain_file - -def restart_line(): # for printing chopped image labels in command line - sys.stdout.write('\r') - sys.stdout.flush() - -def file_len(fname): # get txt file length (number of lines) - with open(fname) as f: - for i, l in enumerate(f): - pass - return i + 1 - -def make_folder(directory): - if not os.path.exists(directory): - os.makedirs(directory) # make directory if it does not exit already # make new directory # Check if folder exists, if not make it - -def make_all_folders(dirs): - - - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/masks') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/masks') - - make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions') - make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirHR'] + '/masks') - - make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/Augment' +'/regions') - make_folder(dirs['basedir']+dirs['project']+ dirs['tempdirHR'] + '/Augment' +'/masks') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['modeldir']) - make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir']) - - - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'regions/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'masks/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'regions/') - make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'masks/') - - make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir']) - - make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetLR') - make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetHR') - -def return_region(args, wsi_mask, wsiID, fileID, yStart, xStart, idxy, idxx, downsampleRate, outdirT, region_size, dirs, chop_regions,classNum_HR): # perform cutting in parallel - sys.stdout.write(' <'+str(xStart)+'/'+ str(yStart)+'/'+str(chop_regions[idxy,idxx] != 0)+ '> ') - sys.stdout.flush() - restart_line() - - if chop_regions[idxy,idxx] != 0: - - uniqID=fileID + str(yStart) + str(xStart) - if wsiID.split('.')[-1] != 'tif': - slide=getWsi(wsiID) - Im=np.array(slide.read_region((xStart,yStart),0,(region_size,region_size))) - Im=Im[:,:,:3] - else: - yEnd = yStart + region_size - xEnd = xStart + region_size - Im = np.zeros([region_size,region_size,3], dtype=np.uint8) - Im_ = imread(wsiID)[yStart:yEnd, xStart:xEnd, :3] - Im[0:Im_.shape[0], 0:Im_.shape[1], :] = Im_ - - mask_annotation=wsi_mask[yStart:yStart+region_size,xStart:xStart+region_size] - - o1,o2=mask_annotation.shape - if o1 !=region_size: - mask_annotation=np.pad(mask_annotation,((0,region_size-o1),(0,0)),mode='constant') - if o2 !=region_size: - mask_annotation=np.pad(mask_annotation,((0,0),(0,region_size-o2)),mode='constant') - - ''' - if 4 in np.unique(mask_annotation): - plt.subplot(121) - plt.imshow(mask_annotation*20) - plt.subplot(122) - plt.imshow(Im) - pt=[xStart,yStart] - plt.title(pt) - plt.show() - ''' - if downsampleRate !=1: - c=(Im.shape) - s1=int(c[0]/(downsampleRate**.5)) - s2=int(c[1]/(downsampleRate**.5)) - Im=resize(Im,(s1,s2),mode='reflect') - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - imsave(outdirT + '/regions/' + uniqID + dirs['imExt'],Im) - imsave(outdirT + '/masks/' + uniqID +dirs['maskExt'],mask_annotation) - - -def regions_in_mask(root, bounds, verbose=1): - # find regions to save - IDs_reg = [] - IDs_points = [] - - for Annotation in root.findall("./Annotation"): # for all annotations - annotationID = Annotation.attrib['Id'] - annotationType = Annotation.attrib['Type'] - - # print(Annotation.findall(./)) - if annotationType =='9': - for element in Annotation.iter('InputAnnotationId'): - pointAnnotationID=element.text - - for Region in Annotation.findall("./*/Region"): # iterate on all region - - for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region - # get points - x_point = np.int32(np.float64(Vertex.attrib['X'])) - y_point = np.int32(np.float64(Vertex.attrib['Y'])) - # test if points are in bounds - if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds - # save region Id - IDs_points.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID,'pointAnnotationID':pointAnnotationID}) - break - elif annotationType=='4': - - for Region in Annotation.findall("./*/Region"): # iterate on all region - - for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region - # get points - x_point = np.int32(np.float64(Vertex.attrib['X'])) - y_point = np.int32(np.float64(Vertex.attrib['Y'])) - # test if points are in bounds - if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds - # save region Id - IDs_reg.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID}) - break - return IDs_reg,IDs_points - - -def get_vertex_points(root, IDs_reg,IDs_points, maskModes,excludedIDs,negativeIDs=None): - Regions = [] - Points = [] - - for ID in IDs_reg: - Vertices = [] - if ID['annotationID'] not in excludedIDs: - for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"): - Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))]) - Regions.append({'Vertices':np.array(Vertices),'annotationID':ID['annotationID']}) - - for ID in IDs_points: - Vertices = [] - for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"): - Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))]) - Points.append({'Vertices':np.array(Vertices),'pointAnnotationID':ID['pointAnnotationID']}) - if 'falsepositive' or 'negative' in maskModes: - assert negativeIDs is not None,'Negatively annotated classes must be provided for negative/falsepositive mask mode' - assert 'falsepositive' and 'negative' not in maskModes, 'Negative and false positive mask modes cannot both be true' - - useableRegions=[] - if 'positive' in maskModes: - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - for Point in Points: - if Region['annotationID'] not in negativeIDs: - if regionPath.contains_point(Point['Vertices'][0]): - Region['pointAnnotationID']=Point['pointAnnotationID'] - useableRegions.append(Region) - - if 'negative' in maskModes: - - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - if Region['annotationID'] in negativeIDs: - if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]): - Region['pointAnnotationID']=Region['annotationID'] - useableRegions.append(Region) - if 'falsepositive' in maskModes: - - for Region in Regions: - regionPath=path.Path(Region['Vertices']) - if Region['annotationID'] in negativeIDs: - if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]): - Region['pointAnnotationID']=0 - useableRegions.append(Region) - - return useableRegions -def chop_suey_bounds(lb,xmlID,box_supervision_layers,wsiID,dirs,args): - tree = ET.parse(xmlID) - root = tree.getroot() - lbVerts=np.array(lb['BoxVerts']) - xMin=min(lbVerts[:,0]) - xMax=max(lbVerts[:,0]) - yMin=min(lbVerts[:,1]) - yMax=max(lbVerts[:,1]) - - # test=np.array(slide.read_region((xMin,yMin),0,(xMax-xMin,yMax-yMin)))[:,:,:3] - - local_bound = {'x_min' : xMin, 'y_min' : yMin, 'x_max' : xMax, 'y_max' : yMax} - IDs_reg,IDs_points = regions_in_mask_dots(root=root, bounds=local_bound,box_layers=box_supervision_layers) - - # find regions in bounds - negativeIDs=['4'] - excludedIDs=['1'] - falsepositiveIDs=['4'] - usableRegions= get_vertex_points_dots(root=root, IDs_reg=IDs_reg,IDs_points=IDs_points,excludedIDs=excludedIDs,maskModes=['falsepositive','positive'],negativeIDs=negativeIDs, - falsepositiveIDs=falsepositiveIDs) - - # image_sizes= - masks_from_points(usableRegions,wsiID,dirs,50,args,[xMin,xMax,yMin,yMax]) -''' -def masks_from_points(root,usableRegions,wsiID,dirs): - pas_img = getWsi(wsiID) - image_sizes=[] - basename=wsiID.split('/')[-1].split('.svs')[0] - - for usableRegion in tqdm(usableRegions): - vertices=usableRegion['Vertices'] - x1=min(vertices[:,0]) - x2=max(vertices[:,0]) - y1=min(vertices[:,1]) - y2=max(vertices[:,1]) - points = np.stack([np.asarray(vertices[:,0]), np.asarray(vertices[:,1])], axis=1) - if (x2-x1)>0 and (y2-y1)>0: - l1=x2-x1 - l2=y2-y1 - xMultiplier=np.ceil((l1)/64) - yMultiplier=np.ceil((l2)/64) - pad1=int(xMultiplier*64-l1) - pad2=int(yMultiplier*64-l2) - - points[:,1] = np.int32(np.round(points[:,1] - y1 )) - points[:,0] = np.int32(np.round(points[:,0] - x1 )) - mask = 2*np.ones([y2-y1,x2-x1], dtype=np.uint8) - if int(usableRegion['pointAnnotationID'])==0: - pass - else: - cv2.fillPoly(mask, [points], int(usableRegion['pointAnnotationID'])-4) - PAS = pas_img.read_region((x1,y1), 0, (x2-x1,y2-y1)) - # print(usableRegion['pointAnnotationID']) - PAS = np.array(PAS)[:,:,0:3] - mask=np.pad( mask,((0,pad2),(0,pad1)),'constant',constant_values=(2,2) ) - PAS=np.pad( PAS,((0,pad2),(0,pad1),(0,0)),'constant',constant_values=(0,0) ) - - image_identifier=basename+'_'.join(['',str(x1),str(y1),str(l1),str(l2)]) - mask_out_name=dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/'+image_identifier+'.png' - image_out_name=mask_out_name.replace('/masks/','/regions/') - # basename + '_' + str(image_identifier) + args.imBoxExt - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - imsave(image_out_name,PAS) - imsave(mask_out_name,mask) - # exit() - # extract image region - # plt.subplot(121) - # plt.imshow(PAS) - # plt.subplot(122) - # plt.imshow(mask) - # plt.show() - # image_sizes.append([x2-x1,y2-y1]) + _ = os.system("printf '\nTraining completed!\n'") + +def mask2polygons(mask): + annotation=[] + presentclasses=np.unique(mask) + offset=-3 + presentclasses=presentclasses[presentclasses>2] + presentclasses=list(presentclasses[presentclasses<7]) + for p in presentclasses: + contours, hierarchy = cv2.findContours(np.array(mask==p, dtype='uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + for contour in contours: + if contour.size>=6: + instance_dict={} + contour_flat=contour.flatten().astype('float').tolist() + xMin=min(contour_flat[::2]) + yMin=min(contour_flat[1::2]) + xMax=max(contour_flat[::2]) + yMax=max(contour_flat[1::2]) + instance_dict['bbox']=[xMin,yMin,xMax,yMax] + instance_dict['bbox_mode']=BoxMode.XYXY_ABS.value + instance_dict['category_id']=p+offset + instance_dict['segmentation']=[contour_flat] + annotation.append(instance_dict) + return annotation + + +class Trainer(DefaultTrainer): + + @classmethod + def build_test_loader(cls, cfg, dataset_name): + return build_detection_test_loader(cfg, dataset_name, mapper=CustomDatasetMapper(cfg, True)) + + @classmethod + def build_train_loader(cls, cfg): + return build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True)) + + +class CustomDatasetMapper: + + @configurable + def __init__( + self, + is_train: bool, + *, + augmentations: List[Union[T.Augmentation, T.Transform]], + image_format: str, + use_instance_mask: bool = False, + use_keypoint: bool = False, + instance_mask_format: str = "polygon", + keypoint_hflip_indices: Optional[np.ndarray] = None, + precomputed_proposal_topk: Optional[int] = None, + recompute_boxes: bool = False, + ): + """ + NOTE: this interface is experimental. + + Args: + is_train: whether it's used in training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + use_instance_mask: whether to process instance segmentation annotations, if available + use_keypoint: whether to process keypoint annotations if available + instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation + masks into this format. + keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` + precomputed_proposal_topk: if given, will load pre-computed + proposals from dataset_dict and keep the top k proposals for each image. + recompute_boxes: whether to overwrite bounding box annotations + by computing tight bounding boxes from instance mask annotations. + """ + if recompute_boxes: + assert use_instance_mask, "recompute_boxes requires instance masks" + # fmt: off + self.is_train = is_train + self.augmentations = T.AugmentationList(augmentations) + self.image_format = image_format + self.use_instance_mask = use_instance_mask + self.instance_mask_format = instance_mask_format + self.use_keypoint = use_keypoint + self.keypoint_hflip_indices = keypoint_hflip_indices + self.proposal_topk = precomputed_proposal_topk + self.recompute_boxes = recompute_boxes + # fmt: on + logger = logging.getLogger(__name__) + mode = "training" if is_train else "inference" + logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") + + @classmethod + def from_config(cls, cfg, is_train: bool = True): + augs = utils.build_augmentation(cfg, is_train) + if cfg.INPUT.CROP.ENABLED and is_train: + augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) + recompute_boxes = cfg.MODEL.MASK_ON else: - print('Broken region') - return image_sizes -''' + recompute_boxes = False + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "use_instance_mask": cfg.MODEL.MASK_ON, + "instance_mask_format": cfg.INPUT.MASK_FORMAT, + "use_keypoint": cfg.MODEL.KEYPOINT_ON, + "recompute_boxes": recompute_boxes, + } + + if cfg.MODEL.KEYPOINT_ON: + ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) + + if cfg.MODEL.LOAD_PROPOSALS: + ret["precomputed_proposal_topk"] = ( + cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN + if is_train + else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST + ) + return ret + + def _transform_annotations(self, dataset_dict, transforms, image_shape): + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.use_instance_mask: + anno.pop("segmentation", None) + if not self.use_keypoint: + anno.pop("keypoints", None) + + annos = [ + utils.transform_instance_annotations( + obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices + ) + for obj in dataset_dict.pop('annotations') + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.instance_mask_format + ) + + if self.recompute_boxes: + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict["instances"] = utils.filter_empty_instances(instances) + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + c=dataset_dict['coordinates'] + h=dataset_dict['height'] + w=dataset_dict['width'] + + slide= TiffSlide(dataset_dict['slide_loc']) + image=np.array(slide.read_region((c[0],c[1]),0,(h,w)))[:,:,:3] + slide.close() + maskData=xml_to_mask(dataset_dict['xml_loc'], c, [h,w]) + + if random.random()>0.5: + hShift=np.random.normal(0,0.05) + lShift=np.random.normal(1,0.025) + # imageblock[im]=randomHSVshift(imageblock[im],hShift,lShift) + image=rgb2hsv(image) + image[:,:,0]=(image[:,:,0]+hShift) + image=hsv2rgb(image) + image=rgb2lab(image) + image[:,:,0]=exposure.adjust_gamma(image[:,:,0],lShift) + image=(lab2rgb(image)*255).astype('uint8') + image = seq(images=[image])[0].squeeze() + + dataset_dict['annotations']=mask2polygons(maskData) + utils.check_image_size(dataset_dict, image) + + sem_seg_gt = maskData + sem_seg_gt[sem_seg_gt>2]=0 + sem_seg_gt[maskData==0] = 3 + sem_seg_gt=np.array(sem_seg_gt).astype('uint8') + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + transforms = self.augmentations(aug_input) + image, sem_seg_gt = aug_input.image, aug_input.sem_seg + + image_shape = image.shape[:2] # h, w + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) + + if "annotations" in dataset_dict: + self._transform_annotations(dataset_dict, transforms, image_shape) + + + return dataset_dict diff --git a/multic/segmentationschool/Codes/wsi_loader_utils.py b/multic/segmentationschool/Codes/wsi_loader_utils.py index ea87ae6..0e1edb4 100644 --- a/multic/segmentationschool/Codes/wsi_loader_utils.py +++ b/multic/segmentationschool/Codes/wsi_loader_utils.py @@ -1,85 +1,85 @@ -import openslide,glob,os +import glob, os import numpy as np from scipy.ndimage.morphology import binary_fill_holes -import matplotlib.pyplot as plt from skimage.color import rgb2hsv from skimage.filters import gaussian -# from skimage.morphology import binary_dilation, diamond -# import cv2 from tqdm import tqdm from skimage.io import imread,imsave import multiprocessing from joblib import Parallel, delayed - -def save_thumb(args,slide_loc): - print(slide_loc) - slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - slide=openslide.OpenSlide(slide_loc) - if slideExt =='.scn': - dim_x=int(slide.properties['openslide.bounds-width'])## add to columns - dim_y=int(slide.properties['openslide.bounds-height'])## add to rows - offsetx=int(slide.properties['openslide.bounds-x'])##start column - offsety=int(slide.properties['openslide.bounds-y'])##start row - elif slideExt in ['.ndpi','.svs']: - dim_x, dim_y=slide.dimensions - offsetx=0 - offsety=0 - - # fullSize=slide.level_dimensions[0] - # resRatio= args.chop_thumbnail_resolution - # ds_1=fullSize[0]/resRatio - # ds_2=fullSize[1]/resRatio - # thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2))) - # if slideExt =='.scn': - # xStt=int(offsetx/resRatio) - # xStp=int((offsetx+dim_x)/resRatio) - # yStt=int(offsety/resRatio) - # yStp=int((offsety+dim_y)/resRatio) - # thumbIm=thumbIm[yStt:yStp,xStt:xStp] - # imsave(slide_loc.replace(slideExt,'_thumb.jpeg'),thumbIm) - slide.associated_images['label'].save(slide_loc.replace(slideExt,'_label.png')) - # imsave(slide_loc.replace(slideExt,'_label.png'),slide.associated_images['label']) - - -def get_image_thumbnails(args): - assert args.target is not None, 'Location of images must be provided' - all_slides=[] - for ext in args.wsi_ext.split(','): - all_slides.extend(glob.glob(args.target+'/*'+ext)) - Parallel(n_jobs=multiprocessing.cpu_count())(delayed(save_thumb)(args,slide_loc) for slide_loc in tqdm(all_slides)) - # for slide_loc in tqdm(all_slides): - -class WSIPredictLoader(): - def __init__(self,args, wsi_directory=None, transform=None): +from shapely.geometry import Polygon +from tiffslide import TiffSlide +import random +import glob +import warnings +from joblib import Parallel, delayed +import multiprocessing +from .xml_to_mask_minmax import write_minmax_to_xml +import lxml.etree as ET + +def get_image_meta(i,args): + image_annotation_info={} + image_annotation_info['slide_loc']=i[0] + slide=TiffSlide(image_annotation_info['slide_loc']) + magx=np.round(float(slide.properties['tiffslide.mpp-x']),2) + magy=np.round(float(slide.properties['tiffslide.mpp-y']),2) + + assert magx == magy + if magx ==0.25: + dx=args.boxSize + dy=args.boxSize + elif magx == 0.5: + dx=int(args.boxSize/2) + dy=int(args.boxSize/2) + else: + print('nonstandard image magnification') + print(slide) + print(magx,magy) + exit() + + image_annotation_info['coordinates']=[i[2][1],i[2][0]] + image_annotation_info['height']=dx + image_annotation_info['width']=dy + image_annotation_info['image_id']=i[1].split('/')[-1].replace('.xml','_'.join(['',str(i[2][1]),str(i[2][0])])) + image_annotation_info['xml_loc']=i[1] + image_annotation_info['file_name']=i[1].split('/')[-1] + slide.close() + return image_annotation_info + +def train_samples_from_WSI(args,image_coordinates): + + + num_cores=multiprocessing.cpu_count() + print('Generating detectron2 dictionary format...') + data_list=Parallel(n_jobs=num_cores,backend='threading')(delayed(get_image_meta)(i=i, + args=args) for i in tqdm(image_coordinates)) + return data_list + +def WSIGridIterator(wsi_name,choppable_regions,index_x,index_y,region_size,dim_x,dim_y): + wsi_name=os.path.splitext(wsi_name.split('/')[-1])[0] + data_list=[] + for idxy, i in tqdm(enumerate(index_y)): + for idxx, j in enumerate(index_x): + if choppable_regions[idxy, idxx] != 0: + yEnd = min(dim_y,i+region_size) + xEnd = min(dim_x,j+region_size) + xLen=xEnd-j + yLen=yEnd-i + + image_annotation_info={} + image_annotation_info['file_name']='_'.join([wsi_name,str(j),str(i),str(xEnd),str(yEnd)]) + image_annotation_info['height']=yLen + image_annotation_info['width']=xLen + image_annotation_info['image_id']=image_annotation_info['file_name'] + image_annotation_info['xStart']=j + image_annotation_info['yStart']=i + data_list.append(image_annotation_info) + return data_list + +def get_slide_data(args, wsi_directory=None): assert wsi_directory is not None, 'location of training svs and xml must be provided' - mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),) - if not os.path.exists(mask_out_loc): - os.makedirs(mask_out_loc) - all_slides=[] - for ext in args.wsi_ext.split(','): - all_slides.extend(glob.glob(wsi_directory+'/*'+ext)) - print('Getting slide metadata and usable regions...') - usable_slides=[] - for slide_loc in all_slides: - slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - print("working slide... "+ slideID,end='\r') - - slide=openslide.OpenSlide(slide_loc) - chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) - mag_x=np.round(float(slide.properties['openslide.mpp-x']),2) - mag_y=np.round(float(slide.properties['openslide.mpp-y']),2) - print(mag_x,mag_y) - usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide, - 'chop_array':chop_array,'mag':[mag_x,mag_y]}) - self.usable_slides= usable_slides - self.boxSize40X = args.boxSize - self.boxSize20X = int(args.boxSize)/2 - -class WSITrainingLoader(): - def __init__(self,args, wsi_directory=None, transform=None): - assert wsi_directory is not None, 'location of training svs and xml must be provided' - mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),) + mask_out_loc=os.path.join(wsi_directory, 'Tissue_masks') if not os.path.exists(mask_out_loc): os.makedirs(mask_out_loc) all_slides=[] @@ -90,34 +90,102 @@ def __init__(self,args, wsi_directory=None, transform=None): usable_slides=[] for slide_loc in all_slides: slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1]) - print("working slide... "+ slideID,end='\r') - - slide=openslide.OpenSlide(slide_loc) - chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) - mag_x=np.round(float(slide.properties['openslide.mpp-x']),2) - mag_y=np.round(float(slide.properties['openslide.mpp-y']),2) - print(mag_x,mag_y) - usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide, - 'chop_array':chop_array,'mag':[mag_x,mag_y]}) - self.usable_slides= usable_slides - self.boxSize40X = args.boxSize - self.boxSize20X = int(args.boxSize)/2 + xmlpath=slide_loc.replace(slideExt,'.xml') + if os.path.isfile(xmlpath): + write_minmax_to_xml(xmlpath) + + print("Gathering slide data ... "+ slideID,end='\r') + slide =TiffSlide(slide_loc) + chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc) + + mag_x=np.round(float(slide.properties['tiffslide.mpp-x']),2) + mag_y=np.round(float(slide.properties['tiffslide.mpp-y']),2) + slide.close() + tree = ET.parse(xmlpath) + root = tree.getroot() + balance_classes=args.balanceClasses.split(',') + classNums={} + for b in balance_classes: + classNums[b]=0 + # balance_annotations={} + for Annotation in root.findall("./Annotation"): + + annotationID = Annotation.attrib['Id'] + if annotationID=='7': + print(xmlpath) + exit() + if annotationID in classNums.keys(): + + classNums[annotationID]=len(Annotation.findall("./*/Region")) + else: + pass + + usable_slides.append({'slide_loc':slide_loc,'slideID':slideID, + 'chop_array':chop_array,'num_regions':len(chop_array),'mag':[mag_x,mag_y], + 'xml_loc':xmlpath,'annotations':classNums,'root':root + }) + else: + print('\n') + print('no annotation XML file found for:') + print(slideID) + exit() print('\n') + return usable_slides + +def get_random_chops(slide_idx,usable_slides,region_size): + # chops=[] + choplen=len(slide_idx) + chops=Parallel(n_jobs=multiprocessing.cpu_count(),backend='threading')(delayed(get_chop_data)(idx=idx, + usable_slides=usable_slides,region_size=region_size) for idx in tqdm(slide_idx)) + return chops + + +def get_chop_data(idx,usable_slides,region_size): + if random.random()>0.5: + randSelect=random.randrange(0,usable_slides[idx]['num_regions']) + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + usable_slides[idx]['chop_array'][randSelect]] + else: + # print(list(usable_slides[idx]['annotations'].values())) + if sum(usable_slides[idx]['annotations'].values())==0: + randSelect=random.randrange(0,usable_slides[idx]['num_regions']) + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + usable_slides[idx]['chop_array'][randSelect]] + else: + classIDs=list(usable_slides[idx]['annotations'].keys()) + classSamples=random.sample(classIDs,len(classIDs)) + for c in classSamples: + if usable_slides[idx]['annotations'][c]==0 or c == '5': + continue + else: + sampledRegionID=random.randrange(1,usable_slides[idx]['annotations'][c]+1) + + break + + + Verts = usable_slides[idx]['root'].findall("./Annotation[@Id='{}']/Regions/Region[@Id='{}']/Vertices/Vertex".format(c,sampledRegionID)) + centroid = (Polygon([(int(float(k.attrib['X'])),int(float(k.attrib['Y']))) for k in Verts]).centroid) + + randVertX=int(centroid.x)-region_size//2 + randVertY=int(centroid.y)-region_size//2 + + chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'], + [randVertY,randVertX]] + + return chopData def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): slide_regions=[] choppable_regions_list=[] - downsample = int(args.downsampleRate**.5) #down sample for each dimension region_size = int(args.boxSize*(downsample)) #Region size before downsampling - step = int(region_size*(1-args.overlap_percent)) #Step size before downsampling - + step = int(region_size*(1-args.overlap_rate)) #Step size before downsampling if slideExt =='.scn': - dim_x=int(slide.properties['openslide.bounds-width'])## add to columns - dim_y=int(slide.properties['openslide.bounds-height'])## add to rows - offsetx=int(slide.properties['openslide.bounds-x'])##start column - offsety=int(slide.properties['openslide.bounds-y'])##start row + dim_x=int(slide.properties['tiffslide.bounds-width'])## add to columns + dim_y=int(slide.properties['tiffslide.bounds-height'])## add to rows + offsetx=int(slide.properties['tiffslide.bounds-x'])##start column + offsety=int(slide.properties['tiffslide.bounds-y'])##start row index_y=np.array(range(offsety,offsety+dim_y,step)) index_x=np.array(range(offsetx,offsetx+dim_x,step)) index_y[-1]=(offsety+dim_y)-step @@ -135,7 +203,9 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): resRatio= args.chop_thumbnail_resolution ds_1=fullSize[0]/resRatio ds_2=fullSize[1]/resRatio - if args.get_new_tissue_masks: + out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png'])) + if not os.path.isfile(out_mask_name) or args.get_new_tissue_masks: + print(out_mask_name) thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2))) if slideExt =='.scn': xStt=int(offsetx/resRatio) @@ -143,40 +213,18 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): yStt=int(offsety/resRatio) yStp=int((offsety+dim_y)/resRatio) thumbIm=thumbIm[yStt:yStp,xStt:xStp] - # plt.imshow(thumbIm) - # plt.show() - # input() - # plt.imshow(thumbIm) - # plt.show() - - out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png'])) - - - if not args.get_new_tissue_masks: - try: - binary=(imread(out_mask_name)/255).astype('bool') - except: - print('failed to load mask for '+ out_mask_name) - print('please set get_new_tissue masks to True') - exit() - # if slideExt =='.scn': - # choppable_regions=np.zeros((len(index_x),len(index_y))) - # elif slideExt in ['.ndpi','.svs']: - choppable_regions=np.zeros((len(index_y),len(index_x))) - else: - print(out_mask_name) - # if slideExt =='.scn': - # choppable_regions=np.zeros((len(index_x),len(index_y))) - # elif slideExt in ['.ndpi','.svs']: choppable_regions=np.zeros((len(index_y),len(index_x))) - hsv=rgb2hsv(thumbIm) g=gaussian(hsv[:,:,1],5) binary=(g>0.05).astype('bool') binary=binary_fill_holes(binary) imsave(out_mask_name.replace('.png','.jpeg'),thumbIm) - imsave(out_mask_name,binary.astype('uint8')*255) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + imsave(out_mask_name,binary.astype('uint8')*255) + binary=(imread(out_mask_name)/255).astype('bool') + choppable_regions=np.zeros((len(index_y),len(index_x))) chop_list=[] for idxy,yi in enumerate(index_y): for idxx,xj in enumerate(index_x): @@ -185,31 +233,13 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc): xStart = int(np.round((xj-offsetx)/resRatio)) xStop = int(np.round(((xj-offsetx)+args.boxSize)/resRatio)) box_total=(xStop-xStart)*(yStop-yStart) - if slideExt =='.scn': - # print(xStart,xStop,yStart,yStop) - # print(np.sum(binary[xStart:xStop,yStart:yStop]),args.white_percent,box_total) - # plt.imshow(binary[xStart:xStop,yStart:yStop]) - # plt.show() - if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): - - choppable_regions[idxy,idxx]=1 - chop_list.append([index_y[idxy],index_x[idxx]]) - - elif slideExt in ['.ndpi','.svs']: - if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): - choppable_regions[idxy,idxx]=1 - chop_list.append([index_y[idxy],index_x[idxx]]) - - imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255) - - # plt.imshow(choppable_regions) - # plt.show() - # choppable_regions_list.extend(chop_list) - # plt.subplot(131) - # plt.imshow(thumbIm) - # plt.subplot(132) - # plt.imshow(binary) - # plt.subplot(133) - # plt.imshow(choppable_regions) - # plt.show() + + if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total): + choppable_regions[idxy,idxx]=1 + chop_list.append([index_y[idxy],index_x[idxx]]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255) + return chop_list diff --git a/multic/segmentationschool/segmentation_school.py b/multic/segmentationschool/segmentation_school.py index 50e32c1..a0d5b42 100644 --- a/multic/segmentationschool/segmentation_school.py +++ b/multic/segmentationschool/segmentation_school.py @@ -174,7 +174,7 @@ def savetime(args, starttime): ##### Args for training / prediction #################################################### parser.add_argument('--gpu_num', dest='gpu_num', default=2 ,type=int, help='number of GPUs avalable') - parser.add_argument('--gpu', dest='gpu', default=0 ,type=int, + parser.add_argument('--gpu', dest='gpu', default="" ,type=str, help='GPU to use for prediction') parser.add_argument('--iteration', dest='iteration', default='none' ,type=str, help='Which iteration to use for prediction') @@ -188,6 +188,18 @@ def savetime(args, starttime): help='number of classes present in the High res training data [USE ONLY IF DIFFERENT FROM LOW RES]') parser.add_argument('--modelfile', dest='modelfile', default=None ,type=str, help='the desired model file to use for training or prediction') + parser.add_argument('--init_modelfile', dest='init_modelfile', default=None ,type=str, + help='the desired model file to use for training or prediction') + parser.add_argument('--eval_period', dest='eval_period', default=1000 ,type=int, + help='Validation Period') + parser.add_argument('--batch_size', dest='batch_size', default=4 ,type=int, + help='Size of batches for training high resolution CNN') + parser.add_argument('--train_steps', dest='train_steps', default=1000 ,type=int, + help='Size of batches for training high resolution CNN') + parser.add_argument('--training_data_dir', dest='training_data_dir', default=os.getcwd(),type=str, + help='Training Data Folder') + parser.add_argument('--overlap_rate', dest='overlap_rate', default=0.5 ,type=float, + help='overlap percentage of high resolution blocks [0-1]') ### Params for cutting wsi ### #White level cutoff @@ -202,10 +214,14 @@ def savetime(args, starttime): help='size of low resolution blocks') parser.add_argument('--downsampleRateLR', dest='downsampleRateLR', default=16 ,type=int, help='reduce image resolution to 1/downsample rate') + parser.add_argument('--get_new_tissue_masks', dest='get_new_tissue_masks', default=False,type=str2bool, + help="Don't load usable tisse regions from disk, create new ones") + parser.add_argument('--downsampleRate', dest='downsampleRate', default=1 ,type=int, + help='reduce image resolution to 1/downsample rate') #High resolution parameters parser.add_argument('--overlap_percentHR', dest='overlap_percentHR', default=0 ,type=float, help='overlap percentage of high resolution blocks [0-1]') - parser.add_argument('--boxSize', dest='boxSize', default=2048 ,type=int, + parser.add_argument('--boxSize', dest='boxSize', default=1200 ,type=int, help='size of high resolution blocks') parser.add_argument('--downsampleRateHR', dest='downsampleRateHR', default=1 ,type=int, help='reduce image resolution to 1/downsample rate') @@ -226,7 +242,8 @@ def savetime(args, starttime): help='Gaussian variance defining bounds on Hue shift for HSV color augmentation') parser.add_argument('--lbound', dest='lbound', default=0.025 ,type=float, help='Gaussian variance defining bounds on L* gamma shift for color augmentation [alters brightness/darkness of image]') - + parser.add_argument('--balanceClasses', dest='balanceClasses', default='3,4,5,6',type=str, + help="which classes to balance during training") ### Params for training networks ### #Low resolution hyperparameters parser.add_argument('--CNNbatch_sizeLR', dest='CNNbatch_sizeLR', default=2 ,type=int, @@ -280,6 +297,8 @@ def savetime(args, starttime): help='padded region for low resolution region extraction') parser.add_argument('--show_interstitium', dest='show_interstitium', default=True ,type=str2bool, help='padded region for low resolution region extraction') + parser.add_argument('--num_workers', dest='num_workers', default=1 ,type=int, + help='Number of workers for data loader') diff --git a/multic/segmentationschool/utils/mask_to_xml.py b/multic/segmentationschool/utils/mask_to_xml.py new file mode 100644 index 0000000..34217e5 --- /dev/null +++ b/multic/segmentationschool/utils/mask_to_xml.py @@ -0,0 +1,135 @@ +import cv2 +import numpy as np +import lxml.etree as ET + +""" +xml_path (string) - the filename of the saved xml +mask (array) - the mask to convert to xml - uint8 array +downsample (int) - amount of downsampling done to the mask + points are upsampled - this can be used to simplify the mask +min_size_thresh (int) - the minimum objectr size allowed in the mask. This is referenced from downsample=1 +xml_color (list) - list of binary color values to be used for classes + +""" + +def mask_to_xml(xml_path, mask, downsample=1, min_size_thresh=0, simplify_contours=0, xml_color=[65280, 65535, 33023, 255, 16711680], verbose=0, return_root=False, maxClass=None, offset={'X': 0,'Y': 0}): + + min_size_thresh /= downsample + + # create xml tree + Annotations = xml_create() + + # get all classes + classes = np.unique(mask) + if maxClass is None: + maxClass = max(classes) + + # add annotation classes to tree + for class_ in range(maxClass+1)[1:]: + if verbose: + print('Creating class: [{}]'.format(class_)) + Annotations = xml_add_annotation(Annotations=Annotations, xml_color=xml_color, annotationID=class_) + + # add contour points to tree classwise + for class_ in classes: # iterate through all classes + + if class_ == 0 or class_ > maxClass: + continue + + if verbose: + print('Working on class [{} of {}]'.format(class_, max(classes))) + + # binarize the mask w.r.t. class_ + binaryMask = mask==class_ + + # get contour points of the mask + pointsList = get_contour_points(binaryMask, downsample=downsample, min_size_thresh=min_size_thresh, simplify_contours=simplify_contours, offset=offset) + for i in range(np.shape(pointsList)[0]): + pointList = pointsList[i] + Annotations = xml_add_region(Annotations=Annotations, pointList=pointList, annotationID=class_) + + if return_root: + # return root, do not save xml file + return Annotations + + # save the final xml file + xml_save(Annotations=Annotations, filename='{}.xml'.format(xml_path.split('.')[0])) + + +def get_contour_points(mask, downsample, min_size_thresh=0, simplify_contours=0, offset={'X': 0,'Y': 0}): + # returns a dict pointList with point 'X' and 'Y' values + # input greyscale binary image + #_, maskPoints, contours = cv2.findContours(np.array(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + maskPoints, contours = cv2.findContours(np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1) + maskPoints = list(maskPoints) + # remove small regions + too_small = [] + for idx, cnt in enumerate(maskPoints): + area = cv2.contourArea(cnt) + if area < min_size_thresh: + too_small.append(idx) + if too_small != []: + too_small.reverse() + for idx in too_small: + maskPoints.pop(idx) + + if simplify_contours > 0: + for idx, cnt in enumerate(maskPoints): + epsilon = simplify_contours*cv2.arcLength(cnt,True) + approx = cv2.approxPolyDP(cnt,epsilon,True) + maskPoints[idx] = approx + + pointsList = [] + for j in range(np.shape(maskPoints)[0]): + pointList = [] + for i in range(0,np.shape(maskPoints[j])[0]): + point = {'X': (maskPoints[j][i][0][0] * downsample) + offset['X'], 'Y': (maskPoints[j][i][0][1] * downsample) + offset['Y']} + pointList.append(point) + pointsList.append(pointList) + return pointsList + +### functions for building an xml tree of annotations ### +def xml_create(): # create new xml tree + # create new xml Tree - Annotations + Annotations = ET.Element('Annotations') + return Annotations + +def xml_add_annotation(Annotations, xml_color, annotationID=None): # add new annotation + # add new Annotation to Annotations + # defualts to new annotationID + if annotationID == None: # not specified + annotationID = len(Annotations.findall('Annotation')) + 1 + Annotation = ET.SubElement(Annotations, 'Annotation', attrib={'Type': '4', 'Visible': '1', 'ReadOnly': '0', 'Incremental': '0', 'LineColorReadOnly': '0', 'LineColor': str(xml_color[annotationID-1]), 'Id': str(annotationID), 'NameReadOnly': '0'}) + Regions = ET.SubElement(Annotation, 'Regions') + return Annotations + +def xml_add_region(Annotations, pointList, annotationID=-1, regionID=None): # add new region to annotation + # add new Region to Annotation + # defualts to last annotationID and new regionID + Annotation = Annotations.find("Annotation[@Id='" + str(annotationID) + "']") + Regions = Annotation.find('Regions') + if regionID == None: # not specified + regionID = len(Regions.findall('Region')) + 1 + Region = ET.SubElement(Regions, 'Region', attrib={'NegativeROA': '0', 'ImageFocus': '-1', 'DisplayId': '1', 'InputRegionId': '0', 'Analyze': '0', 'Type': '0', 'Id': str(regionID)}) + Vertices = ET.SubElement(Region, 'Vertices') + for point in pointList: # add new Vertex + ET.SubElement(Vertices, 'Vertex', attrib={'X': str(point['X']), 'Y': str(point['Y']), 'Z': '0'}) + # add connecting point + ET.SubElement(Vertices, 'Vertex', attrib={'X': str(pointList[0]['X']), 'Y': str(pointList[0]['Y']), 'Z': '0'}) + return Annotations + +def xml_save(Annotations, filename): + xml_data = ET.tostring(Annotations, pretty_print=True) + #xml_data = Annotations.toprettyxml() + f = open(filename, 'w') + f.write(xml_data.decode()) + f.close() + +def read_xml(filename): + # import xml file + tree = ET.parse(filename) + root = tree.getroot() + +if __name__ == '__main__': + main() + \ No newline at end of file diff --git a/multic/segmentationschool/utils/xml_to_mask.py b/multic/segmentationschool/utils/xml_to_mask.py new file mode 100644 index 0000000..aeecd6a --- /dev/null +++ b/multic/segmentationschool/utils/xml_to_mask.py @@ -0,0 +1,219 @@ +import numpy as np +import lxml.etree as ET +import cv2 +import time +import os + +""" +location (tuple) - (x, y) tuple giving the top left pixel in the level 0 reference frame +size (tuple) - (width, height) tuple giving the region size | set to 'full' for entire mask +downsample - int giving the amount of downsampling done to the output pixel mask + +NOTE: if you plan to loop through xmls parallely, it is nessesary to run write_minmax_to_xml() + on all the files prior - to avoid conflicting file writes + +""" + +def xml_to_mask(xml_path, location, size, tree=None, downsample=1, verbose=0): + + # parse xml and get root + if tree == None: tree = ET.parse(xml_path) + root = tree.getroot() + + if size == 'full': + import math + size = write_minmax_to_xml(xml_path=xml_path, tree=tree, get_absolute_max=True) + size = (math.ceil(size[0]/downsample), math.ceil(size[1]/downsample)) + location = (0,0) + + # calculate region bounds + bounds = {'x_min' : location[0], 'y_min' : location[1], 'x_max' : location[0] + size[0]*downsample, 'y_max' : location[1] + size[1]*downsample} + + IDs = regions_in_mask(xml_path=xml_path, root=root, tree=tree, bounds=bounds, verbose=verbose) + + if verbose != 0: + print('\nFOUND: ' + str(len(IDs)) + ' regions') + + # find regions in bounds + Regions = get_vertex_points(root=root, IDs=IDs, verbose=verbose) + + # fill regions and create mask + mask = Regions_to_mask(Regions=Regions, bounds=bounds, IDs=IDs, downsample=downsample, verbose=verbose) + if verbose != 0: + print('done...\n') + + return mask + + +def regions_in_mask(xml_path, root, tree, bounds, verbose=1): + # find regions to save + IDs = [] + mtime = os.path.getmtime(xml_path) + + write_minmax_to_xml(xml_path, tree) + + for Annotation in root.findall("./Annotation"): # for all annotations + annotationID = Annotation.attrib['Id'] + + for Region in Annotation.findall("./*/Region"): # iterate on all region + + for Vert in Region.findall("./Vertices"): # iterate on all vertex in region + + # get minmax points + Xmin = np.int32(Vert.attrib['Xmin']) + Ymin = np.int32(Vert.attrib['Ymin']) + Xmax = np.int32(Vert.attrib['Xmax']) + Ymax = np.int32(Vert.attrib['Ymax']) + + # test minmax points in region bounds + if bounds['x_min'] <= Xmax and bounds['x_max'] >= Xmin and bounds['y_min'] <= Ymax and bounds['y_max'] >= Ymin: + # save region Id + IDs.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID}) + break + return IDs + +def get_vertex_points(root, IDs, verbose=1): + Regions = [] + + for ID in IDs: # for all IDs + + # get all vertex attributes (points) + Vertices = [] + + for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"): + # make array of points + Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))]) + + Regions.append(np.array(Vertices)) + + return Regions + +def Regions_to_mask(Regions, bounds, IDs, downsample, verbose=1): + # downsample = int(np.round(downsample_factor**(.5))) + + if verbose !=0: + print('\nMAKING MASK:') + + if len(Regions) != 0: # regions present + # get min/max sizes + min_sizes = np.empty(shape=[2,0], dtype=np.int32) + max_sizes = np.empty(shape=[2,0], dtype=np.int32) + for Region in Regions: # fill all regions + min_bounds = np.reshape((np.amin(Region, axis=0)), (2,1)) + max_bounds = np.reshape((np.amax(Region, axis=0)), (2,1)) + min_sizes = np.append(min_sizes, min_bounds, axis=1) + max_sizes = np.append(max_sizes, max_bounds, axis=1) + min_size = np.amin(min_sizes, axis=1) + max_size = np.amax(max_sizes, axis=1) + + # add to old bounds + bounds['x_min_pad'] = min(min_size[1], bounds['x_min']) + bounds['y_min_pad'] = min(min_size[0], bounds['y_min']) + bounds['x_max_pad'] = max(max_size[1], bounds['x_max']) + bounds['y_max_pad'] = max(max_size[0], bounds['y_max']) + + # make blank mask + mask = np.zeros([ int(np.round((bounds['y_max_pad'] - bounds['y_min_pad']) / downsample)), int(np.round((bounds['x_max_pad'] - bounds['x_min_pad']) / downsample)) ], dtype=np.uint8) + + # fill mask polygons + index = 0 + for Region in Regions: + # reformat Regions + Region[:,1] = np.int32(np.round((Region[:,1] - bounds['y_min_pad']) / downsample)) + Region[:,0] = np.int32(np.round((Region[:,0] - bounds['x_min_pad']) / downsample)) + # get annotation ID for mask color + ID = IDs[index] + cv2.fillPoly(mask, [Region], int(ID['annotationID'])) + index = index + 1 + + # reshape mask + x_start = np.int32(np.round((bounds['x_min'] - bounds['x_min_pad']) / downsample)) + y_start = np.int32(np.round((bounds['y_min'] - bounds['y_min_pad']) / downsample)) + x_stop = np.int32(np.round((bounds['x_max'] - bounds['x_min_pad']) / downsample)) + y_stop = np.int32(np.round((bounds['y_max'] - bounds['y_min_pad']) / downsample)) + # pull center mask region + mask = mask[ y_start:y_stop, x_start:x_stop ] + + else: # no Regions + mask = np.zeros([ int(np.round((bounds['y_max'] - bounds['y_min']) / downsample)), int(np.round((bounds['x_max'] - bounds['x_min']) / downsample)) ], dtype=np.uint8) + + return mask + +def write_minmax_to_xml(xml_path, tree=None, time_buffer=10, get_absolute_max=False): + # function to write min and max verticies to each region + + # parse xml and get root + if tree == None: tree = ET.parse(xml_path) + root = tree.getroot() + + try: + if get_absolute_max: + # break the try statement + X_max = 0 + Y_max = 0 + raise ValueError + + # has the xml been modified to include minmax + modtime = np.float64(root.attrib['modtime']) + # has the minmax modified xml been changed? + assert os.path.getmtime(xml_path) < modtime + time_buffer + + except: + + for Annotation in root.findall("./Annotation"): # for all annotations + annotationID = Annotation.attrib['Id'] + + for Region in Annotation.findall("./*/Region"): # iterate on all region + + for Vert in Region.findall("./Vertices"): # iterate on all vertex in region + Xs = [] + Ys = [] + for Vertex in Vert.findall("./Vertex"): # iterate on all vertex in region + # get points + Xs.append(np.int32(np.float64(Vertex.attrib['X']))) + Ys.append(np.int32(np.float64(Vertex.attrib['Y']))) + + # find min and max points + Xs = np.array(Xs) + Ys = np.array(Ys) + + if get_absolute_max: + # get the biggest point in annotation + if Xs != [] and Ys != []: + X_max = max(X_max, np.max(Xs)) + Y_max = max(Y_max, np.max(Ys)) + + else: + # modify the xml + Vert.set("Xmin", "{}".format(np.min(Xs))) + Vert.set("Xmax", "{}".format(np.max(Xs))) + Vert.set("Ymin", "{}".format(np.min(Ys))) + Vert.set("Ymax", "{}".format(np.max(Ys))) + + if get_absolute_max: + # return annotation max point + return (X_max,Y_max) + + else: + # modify the xml with minmax region info + root.set("modtime", "{}".format(time.time())) + xml_data = ET.tostring(tree, pretty_print=True) + #xml_data = Annotations.toprettyxml() + f = open(xml_path, 'w') + f.write(xml_data.decode()) + f.close() + + +def get_num_classes(xml_path,ignore_label=None): + # parse xml and get root + tree = ET.parse(xml_path) + root = tree.getroot() + + annotation_num = 0 + for Annotation in root.findall("./Annotation"): # for all annotations + if ignore_label != None: + if not int(Annotation.attrib['Id']) == ignore_label: + annotation_num += 1 + else: annotation_num += 1 + + return annotation_num + 1 diff --git a/setup.py b/setup.py index fff61e6..3581056 100644 --- a/setup.py +++ b/setup.py @@ -45,12 +45,12 @@ def prerelease_local_scheme(version): install_requires=[ # scientific packages 'nimfa>=1.3.2', - 'numpy>=1.21.1', + 'numpy==1.23.5', 'scipy>=0.19.0', 'Pillow==9.5.0', 'pandas>=0.19.2', 'imageio>=2.3.0', - # 'shapely[vectorized]', + 'shapely', #'opencv-python-headless<4.7', #'sqlalchemy', # 'matplotlib', @@ -69,6 +69,7 @@ def prerelease_local_scheme(version): # 'umap-learn==0.5.3', 'openpyxl', 'xlrd<2', + 'imgaug', # dask packages 'dask[dataframe]>=1.1.0', 'distributed>=1.21.6',