diff --git a/Dockerfile b/Dockerfile
index 4ca5afb..505d751 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -134,6 +134,7 @@ WORKDIR $mc_path/multic/cli
RUN python -m slicer_cli_web.cli_list_entrypoint --list_cli
RUN python -m slicer_cli_web.cli_list_entrypoint MultiCompartmentSegment --help
RUN python -m slicer_cli_web.cli_list_entrypoint FeatureExtraction --help
+RUN python -m slicer_cli_web.cli_list_entrypoint MultiCompartmentTrain --help
ENTRYPOINT ["/bin/bash", "docker-entrypoint.sh"]
diff --git a/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml b/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml
index 9c13982..005f086 100644
--- a/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml
+++ b/multic/cli/MultiCompartmentSegment/MultiCompartmentSegment.xml
@@ -4,7 +4,7 @@
   <title>Multi Compartment Segmentation</title>
   <description>Segments multi-level structures from a whole-slide image</description>
   <version>0.1.0</version>
-  <documentation-url>https://github.com/SarderLab/deeplab-WSI</documentation-url>
+  <documentation-url>https://github.com/SarderLab/Multi-Compartment-Segmentation</documentation-url>
   <license>Apache 2.0</license>
   <contributor>Sayat Mimar (UFL)</contributor>
   <acknowledgements>This work is part of efforts in digital pathology by the Sarder Lab: UFL.</acknowledgements>
diff --git a/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py
new file mode 100644
index 0000000..0c31306
--- /dev/null
+++ b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.py
@@ -0,0 +1,173 @@
+import os
+import sys
+from glob import glob
+import girder_client
+from ctk_cli import CLIArgumentParser
+
+sys.path.append("..")
+from segmentationschool.utils.mask_to_xml import xml_create, xml_add_annotation, xml_add_region, xml_save
+from segmentationschool.utils.xml_to_mask import write_minmax_to_xml
+
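+# Annotation layer names expected on each training slide; annotation IDs 1-6 in the
+# generated XML are assigned in this order, and slides missing any layer are skipped.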
+NAMES = ['cortical_interstitium','medullary_interstitium','non_globally_sclerotic_glomeruli','globally_sclerotic_glomeruli','tubules','arteries/arterioles']
+
+
+def main(args):
+
+ folder = args.training_data_dir
+ base_dir_id = folder.split('/')[-2]
+ _ = os.system("printf '\nUsing data from girder_client Folder: {}\n'".format(folder))
+
+ _ = os.system("printf '\n---\n\nFOUND: [{}]\n'".format(args.init_modelfile))
+
+ gc = girder_client.GirderClient(apiUrl=args.girderApiUrl)
+ gc.setToken(args.girderToken)
+ # get files in folder
+ files = gc.listItem(base_dir_id)
+ xml_color=[65280]*(len(NAMES)+1)
+ cwd = os.getcwd()
+ print(cwd)
+ os.chdir(cwd)
+
+ tmp = folder
+
+ slides_used = []
+ ignore_label = len(NAMES)+1
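+    # For each slide item in the Girder folder, convert its annotation layers into an XML
+    # annotation file saved next to the slide; the training code reads label masks from these XMLs.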
+ for file in files:
+ slidename = file['name']
+ _ = os.system("printf '\n---\n\nFOUND: [{}]\n'".format(slidename))
+ skipSlide = 0
+
+ # get annotation
+ item = gc.getItem(file['_id'])
+ annot = gc.get('/annotation/item/{}'.format(item['_id']), parameters={'sort': 'updated'})
+ annot.reverse()
+ annot = list(annot)
+ _ = os.system("printf '\tfound [{}] annotation layers...\n'".format(len(annot)))
+
+ # create root for xml file
+ xmlAnnot = xml_create()
+
+ # all compartments
+ for class_,compart in enumerate(NAMES):
+
+ compart = compart.replace(' ','')
+ class_ +=1
+ # add layer to xml
+ xmlAnnot = xml_add_annotation(Annotations=xmlAnnot, xml_color=xml_color, annotationID=class_)
+
+ # test all annotation layers in order created
+ for iter,a in enumerate(annot):
+
+
+ try:
+ # check for annotation layer by name
+ a_name = a['annotation']['name'].replace(' ','')
+ except:
+ a_name = None
+
+ if a_name == compart:
+ # track all layers present
+ skipSlide +=1
+
+ pointsList = []
+
+ # load json data
+ _ = os.system("printf '\tloading annotation layer: [{}]\n'".format(compart))
+
+ a_data = a['annotation']['elements']
+
+ for data in a_data:
+ pointList = []
+ points = data['points']
+ for point in points:
+ pt_dict = {'X': round(point[0]), 'Y': round(point[1])}
+ pointList.append(pt_dict)
+ pointsList.append(pointList)
+
+ # write annotations to xml
+ for i in range(len(pointsList)):
+ pointList = pointsList[i]
+ xmlAnnot = xml_add_region(Annotations=xmlAnnot, pointList=pointList, annotationID=class_)
+
+ # print(a['_version'], a['updated'], a['created'])
+ break
+
+ if skipSlide != len(NAMES):
+ _ = os.system("printf '\tThis slide is missing annotation layers\n'")
+ _ = os.system("printf '\tSKIPPING SLIDE...\n'")
+ del xmlAnnot
+ continue # correct layers not present
+ # compart = 'ignore_label'
+ # # test all annotation layers in order created
+ # for iter,a in enumerate(annot):
+ # try:
+ # # check for annotation layer by name
+ # a_name = a['annotation']['name'].replace(' ','')
+ # except:
+ # a_name = None
+ # if a_name == compart:
+ # pointsList = []
+ # # load json data
+ # _ = os.system("printf '\tloading annotation layer: [{}]\n'".format(compart))
+ # a_data = a['annotation']['elements']
+ # for data in a_data:
+ # pointList = []
+ # if data['type'] == 'polyline':
+ # points = data['points']
+ # elif data['type'] == 'rectangle':
+ # center = data['center']
+ # width = data['width']/2
+ # height = data['height']/2
+ # points = [[ center[0]-width, center[1]-width ],[ center[0]+width, center[1]+width ]]
+ # for point in points:
+ # pt_dict = {'X': round(point[0]), 'Y': round(point[1])}
+ # pointList.append(pt_dict)
+ # pointsList.append(pointList)
+ # # write annotations to xml
+
+ # for i in range(len(pointsList)):
+ # pointList = pointsList[i]
+ # xmlAnnot = xml_add_region(Annotations=xmlAnnot, pointList=pointList, annotationID=ignore_label)
+ # break
+
+ # include slide and fetch annotations
+ _ = os.system("printf '\tFETCHING SLIDE...\n'")
+ os.rename('{}/{}'.format(folder, slidename), '{}/{}'.format(tmp, slidename))
+ slides_used.append(slidename)
+
+ xml_path = '{}/{}.xml'.format(tmp, os.path.splitext(slidename)[0])
+        _ = os.system("printf '\tsaving generated xml annotation file: [{}]\n'".format(xml_path))
+ xml_save(Annotations=xmlAnnot, filename=xml_path)
+ write_minmax_to_xml(xml_path) # to avoid trying to write to the xml from multiple workers
+ del xmlAnnot
+ os.system("ls -lh '{}'".format(tmp))
+
+ trainlogdir=os.path.join(tmp, 'output')
+ if not os.path.exists(trainlogdir):
+ os.makedirs(trainlogdir)
+
+    _ = os.system("printf '\ndone retrieving data...\nstarting training...\n\n'")
+
+
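+    # Hand the converted data off to segmentation_school.py for the actual training run;
+    # checkpoints are written to <training_data_dir>/output and the newest .pth becomes the output model.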
+ cmd = "python3 ../segmentationschool/segmentation_school.py --option {} --training_data_dir {} --init_modelfile {} --gpu {} --train_steps {} --num_workers {} --girderApiUrl {} --girderToken {}".format('train', tmp.replace(' ', '\ '), args.init_modelfile, args.gpu, args.training_steps, args.num_workers, args.girderApiUrl, args.girderToken)
+ print(cmd)
+ sys.stdout.flush()
+ os.system(cmd)
+
+ os.listdir(trainlogdir)
+ os.chdir(trainlogdir)
+ os.system('pwd')
+ os.system('ls -lh')
+
+ filelist = glob('*.pth')
+ latest_model = max(filelist, key=os.path.getmtime)
+
+ _ = os.system("printf '\n{}\n'".format(latest_model))
+ os.rename(latest_model, args.output_model)
+
+ _ = os.system("printf '\nDone!\n\n'")
+
+
+
+if __name__ == "__main__":
+ main(CLIArgumentParser().parse_args())
diff --git a/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml
new file mode 100644
index 0000000..9f70a39
--- /dev/null
+++ b/multic/cli/MultiCompartmentTrain/MultiCompartmentTrain.xml
@@ -0,0 +1,75 @@
+
+
+ HistomicsTK
+ Multi Compartment Training
+  Trains the multi-compartment segmentation model
+ 0.1.0
+ https://github.com/SarderLab/Multi-Compartment-Segmentation
+ Apache 2.0
+ Sayat Mimar (UFL)
+ This work is part of efforts in digital pathology by the Sarder Lab: UFL.
+
+
+ Input/output parameters
+
+ training_data_dir
+
+      Base directory for the training data
+ input
+ 0
+
+
+ init_modelfile
+
+      Pretrained model file used to initialize training
+ input
+ 1
+
+
+ gpu
+
+      A comma-separated list of the GPU IDs that will be made available for training
+ 0,1
+ 2
+
+
+ training_steps
+
+ The number of steps used for network training. The network will see [steps * batch size] image patches during training
+ 5000
+ 3
+
+
+ num_workers
+
+ Number of workers for Dataloader
+ 0
+ 4
+
+
+ output_model
+
+ Select the name of the output model file produced. By default this will be saved in your Private folder.
+ output
+ 5
+
+
+
+
+ A Girder API URL and token for Girder client
+
+ girderApiUrl
+ api-url
+
+ A Girder API URL (e.g., https://girder.example.com:443/api/v1)
+
+
+
+ girderToken
+ token
+
+ A Girder token
+
+
+
+
diff --git a/multic/cli/slicer_cli_list.json b/multic/cli/slicer_cli_list.json
index 1189b00..b6a048f 100644
--- a/multic/cli/slicer_cli_list.json
+++ b/multic/cli/slicer_cli_list.json
@@ -4,5 +4,8 @@
},
"FeatureExtraction": {
"type" : "python"
+ },
+ "MultiCompartmentTrain": {
+ "type" : "python"
}
}
diff --git a/multic/segmentationschool/Codes/IterativeTraining_1X.py b/multic/segmentationschool/Codes/IterativeTraining_1X.py
index 9c98890..0b92135 100644
--- a/multic/segmentationschool/Codes/IterativeTraining_1X.py
+++ b/multic/segmentationschool/Codes/IterativeTraining_1X.py
@@ -1,677 +1,343 @@
-import os, sys, cv2, time, random, warnings, multiprocessing#json,# detectron2
+import os,cv2, time, random, multiprocessing,copy
+from skimage.color import rgb2hsv,hsv2rgb,rgb2lab,lab2rgb
import numpy as np
-import matplotlib.pyplot as plt
-import lxml.etree as ET
-from matplotlib import path
-from skimage.transform import resize
-from skimage.io import imread, imsave
-import glob
-from .getWsi import getWsi
-
-from .xml_to_mask2 import get_supervision_boxes, regions_in_mask_dots, get_vertex_points_dots, masks_from_points, restart_line
-from joblib import Parallel, delayed
-from shutil import move
+from tiffslide import TiffSlide
+from .xml_to_mask_minmax import xml_to_mask
# from generateTrainSet import generateDatalists
-#from subprocess import call
-#from .get_choppable_regions import get_choppable_regions
-from PIL import Image
-
+import logging
from detectron2.utils.logger import setup_logger
+from skimage import exposure
+
setup_logger()
from detectron2 import model_zoo
-from detectron2.engine import DefaultPredictor,DefaultTrainer
+from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
-from detectron2.utils.visualizer import Visualizer#,ColorMode
-from detectron2.data import MetadataCatalog, DatasetCatalog
-#from detectron2.structures import BoxMode
-from .get_dataset_list import HAIL2Detectron, samples_from_json, samples_from_json_mini
-#from detectron2.checkpoint import DetectionCheckpointer
-#from detectron2.modeling import build_model
-
-"""
-
-Code for - cutting / augmenting / training CNN
-
-This uses WSI and XML files to train 2 neural networks for semantic segmentation
- of histopath tissue via human in the loop training
-
-"""
-
+# from detectron2.data import MetadataCatalog, DatasetCatalog
+from detectron2.data import detection_utils as utils
+import detectron2.data.transforms as T
+from detectron2.structures import BoxMode
+from detectron2.data import (DatasetCatalog,
+ MetadataCatalog,
+ build_detection_test_loader,
+ build_detection_train_loader,
+)
+from detectron2.config import configurable
+from typing import List, Optional, Union
+import torch
+
+# sys.append("..")
+from .wsi_loader_utils import train_samples_from_WSI, get_slide_data, get_random_chops
+from imgaug import augmenters as iaa
+
+
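+# Global imgaug augmentation pipeline: with 50% probability apply one of elementwise noise /
+# impulse noise / coarse dropout, and with 50% probability one of Gaussian blur / sharpening.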
+global seq
+seq = iaa.Sequential([
+ iaa.Sometimes(0.5,iaa.OneOf([
+ iaa.AddElementwise((-15,15),per_channel=0.5),
+ iaa.ImpulseNoise(0.05),iaa.CoarseDropout(0.02, size_percent=0.5)])),
+ iaa.Sometimes(0.5,iaa.OneOf([iaa.GaussianBlur(sigma=(0, 3.0)),
+ iaa.Sharpen(alpha=(0.0, 1.0), lightness=(0.75, 2.0))]))
+])
#Record start time
totalStart=time.time()
def IterateTraining(args):
- ## calculate low resolution block params
- downsampleLR = int(args.downsampleRateLR**.5) #down sample for each dimension
- region_sizeLR = int(args.boxSizeLR*(downsampleLR)) #Region size before downsampling
- stepLR = int(region_sizeLR*(1-args.overlap_percentLR)) #Step size before downsampling
- ## calculate low resolution block params
- downsampleHR = int(args.downsampleRateHR**.5) #down sample for each dimension
- region_sizeHR = int(args.boxSizeHR*(downsampleHR)) #Region size before downsampling
- stepHR = int(region_sizeHR*(1-args.overlap_percentHR)) #Step size before downsampling
-
- global classNum_HR,classEnumLR,classEnumHR
+
+ region_size = int(args.boxSize) #Region size before downsampling
+
dirs = {'imExt': '.jpeg'}
dirs['basedir'] = args.base_dir
dirs['maskExt'] = '.png'
- dirs['modeldir'] = '/MODELS/'
- dirs['tempdirLR'] = '/TempLR/'
- dirs['tempdirHR'] = '/TempHR/'
- dirs['pretraindir'] = '/Deeplab_network/'
- dirs['training_data_dir'] = '/TRAINING_data/'
- dirs['model_init'] = 'deeplab_resnet.ckpt'
- dirs['project']= '/' + args.project
- dirs['data_dir_HR'] = args.base_dir +'/' + args.project + '/Permanent/HR/'
- dirs['data_dir_LR'] = args.base_dir +'/' +args.project + '/Permanent/LR/'
+ dirs['training_data_dir'] = args.training_data_dir
- ##All folders created, initiate WSI loading by human
- #raw_input('Please place WSIs in ')
- ##Check iteration session
- currentmodels=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir'])
- print('Handcoded iteration')
- # currentAnnotationIteration=check_model_generation(dirs)
- currentAnnotationIteration=2
- print('Current training session is: ' + str(currentAnnotationIteration))
-
- ##Create objects for storing class distributions
- annotatedXMLs=glob.glob(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration) + '/*.xml')
- classes=[]
+ print('Handcoded iteration')
- if args.classNum == 0:
- for xml in annotatedXMLs:
- classes.append(get_num_classes(xml))
- classNum_HR = max(classes)
+ os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
+ os.environ["CUDA_LAUNCH_BLOCKING"] ='1'
+
+ organType='kidney'
+ print('Organ meta being set to... '+ organType)
+ if organType=='liver':
+ classnames=['Background','BD','A']
+ isthing=[0,1,1]
+ xml_color = [[0,255,0], [0,255,255], [0,0,255]]
+ tc=['BD','AT']
+ sc=['Ob','B']
+ elif organType =='kidney':
+ classnames=['interstitium','medulla','glomerulus','sclerotic glomerulus','tubule','arterial tree']
+ classes={}
+ isthing=[0,0,1,1,1,1]
+ xml_color = [[0,255,0], [0,255,255], [255,255,0],[0,0,255], [255,0,0], [0,128,255]]
+ tc=['G','SG','T','A']
+ sc=['Ob','I','M','B']
else:
- classNum_LR = args.classNum
- if args.classNum_HR != 0:
- classNum_HR = args.classNum_HR
- else:
- classNum_HR = classNum_LR
-
- classNum_HR=args.classNum
-
- ##for all WSIs in the initiating directory:
- if args.chop_data == 'True':
- print('Chopping')
-
- start=time.time()
- size_data=[]
-
- for xmlID in annotatedXMLs:
-
- #Get unique name of WSI
- fileID=xmlID.split('/')[-1].split('.xml')[0]
- print('-----------------'+fileID+'----------------')
- #create memory addresses for wsi files
- for ext in [args.wsi_ext]:
- wsiID=dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration) +'/'+ fileID + ext
-
- #Ensure annotations exist
- if os.path.isfile(wsiID)==True:
- break
-
-
- #Load openslide information about WSI
- if ext != '.tif':
- slide=getWsi(wsiID)
- #WSI level 0 dimensions (largest size)
- dim_x,dim_y=slide.dimensions
- else:
- im = Image.open(wsiID)
- dim_x, dim_y=im.size
- location=[0,0]
- size=[dim_x,dim_y]
- tree = ET.parse(xmlID)
- root = tree.getroot()
- box_supervision_layers=['8']
- # calculate region bounds
- global_bounds = {'x_min' : location[0], 'y_min' : location[1], 'x_max' : location[0] + size[0], 'y_max' : location[1] + size[1]}
- local_bounds = get_supervision_boxes(root,box_supervision_layers)
- num_cores = multiprocessing.cpu_count()
- Parallel(n_jobs=num_cores)(delayed(chop_suey_bounds)(args=args,wsiID=wsiID,
- dirs=dirs,lb=lb,xmlID=xmlID,box_supervision_layers=box_supervision_layers) for lb in tqdm(local_bounds))
- # for lb in tqdm(local_bounds):
-
- # size_data.extend(image_sizes)
-
- '''
- wsi_mask=xml_to_mask(xmlID, [0,0], [dim_x,dim_y])
+ print('Provided organType not in supported types: kidney, liver')
- #Enumerate cpu core count
- num_cores = multiprocessing.cpu_count()
+ classNum=len(tc)+len(sc)-1
+ print('Number classes: '+ str(classNum))
+ classes={}
- #Generate iterators for parallel chopping of WSIs in high resolution
- #index_yHR=range(30240,dim_y-stepHR,stepHR)
- #index_xHR=range(840,dim_x-stepHR,stepHR)
- index_yHR=range(0,dim_y,stepHR)
- index_xHR=range(0,dim_x,stepHR)
- index_yHR[-1]=dim_y-stepHR
- index_xHR[-1]=dim_x-stepHR
- #Create memory address for chopped images high resolution
- outdirHR=dirs['basedir'] + dirs['project'] + dirs['tempdirHR']
+ for idx,c in enumerate(classnames):
+ classes[idx]={'isthing':isthing[idx],'color':xml_color[idx]}
- #Perform high resolution chopping in parallel and return the number of
- #images in each of the labeled classes
- chop_regions=get_choppable_regions(wsi=wsiID,
- index_x=index_xHR,index_y=index_yHR,boxSize=region_sizeHR,white_percent=args.white_percent)
- Parallel(n_jobs=num_cores)(delayed(return_region)(args=args,
- wsi_mask=wsi_mask, wsiID=wsiID,
- fileID=fileID, yStart=j, xStart=i, idxy=idxy,
- idxx=idxx, downsampleRate=args.downsampleRateHR,
- outdirT=outdirHR, region_size=region_sizeHR,
- dirs=dirs, chop_regions=chop_regions,classNum_HR=classNum_HR) for idxx,i in enumerate(index_xHR) for idxy,j in enumerate(index_yHR))
-
- #for idxx,i in enumerate(index_xHR):
- # for idxy,j in enumerate(index_yHR):
- # if chop_regions[idxy,idxx] != 0:
- # return_region(args=args,xmlID=xmlID, wsiID=wsiID, fileID=fileID, yStart=j, xStart=i,idxy=idxy, idxx=idxx,
- # downsampleRate=args.downsampleRateHR,outdirT=outdirHR, region_size=region_sizeHR, dirs=dirs,
- # chop_regions=chop_regions,classNum_HR=classNum_HR)
- # else:
- # print('pass')
-
-
- # exit()
- print('Time for WSI chopping: ' + str(time.time()-start))
-
- classEnumHR=np.ones([classNum_HR,1])*classNum_HR
-
- ##High resolution augmentation
- #Enumerate high resolution class distribution
- classDistHR=np.zeros(len(classEnumHR))
- for idx,value in enumerate(classEnumHR):
- classDistHR[idx]=value/sum(classEnumHR)
- print(classDistHR)
- #Define number of augmentations per class
-
- moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions/', dirs['basedir']+dirs['project'] + '/Permanent/HR/regions/')
- moveimages(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/masks/',dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/')
-
-
- #Total time
- print('Time for high resolution augmenting: ' + str((time.time()-totalStart)/60) + ' minutes.')
- '''
-
- # with open('sizes.csv','w',newline='') as myfile:
- # wr=csv.writer(myfile,quoting=csv.QUOTE_ALL)
- # wr.writerow(size_data)
- # pretrain_HR=get_pretrain(currentAnnotationIteration,'/HR/',dirs)
-
- modeldir_HR = dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration+1) + '/HR/'
-
-
- ##### HIGH REZ ARGS #####
- dirs['outDirAIHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/regions/'
- dirs['outDirAMHR']=dirs['basedir']+'/'+dirs['project'] + '/Permanent/HR/masks/'
-
-
- numImagesHR=len(glob.glob(dirs['outDirAIHR'] + '*' + dirs['imExt']))
-
- numStepsHR=(args.epoch_HR*numImagesHR)/ args.CNNbatch_sizeHR
-
-
- #-----------------------------------------------------------------------------------------
- # os.environ["CUDA_VISIBLE_DEVICES"]='0'
- os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu)
- # img_dir='/hdd/bg/Detectron2/chop_detectron/Permanent/HR'
-
- img_dir=dirs['outDirAIHR']
- classnames=['Background','BD','A']
- isthing=[0,1,1]
- xml_color = [[0,255,0], [0,255,255], [0,0,255]]
-
- rand_sample=True
-
- json_file=img_dir+'/detectron_train.json'
- HAIL2Detectron(img_dir,rand_sample,json_file,classnames,isthing,xml_color)
- tc=['BD','AT']
- sc=['I','B']
- #### From json
- DatasetCatalog.register("my_dataset", lambda:samples_from_json(json_file,rand_sample))
+ num_images=args.batch_size*args.train_steps
+ # slide_idxs=train_dset.get_random_slide_idx(num_images)
+ usable_slides=get_slide_data(args, wsi_directory = dirs['training_data_dir'])
+ print('Number of slides:', len(usable_slides))
+ usable_idx=range(0,len(usable_slides))
+ slide_idxs=random.choices(usable_idx,k=num_images)
+ image_coordinates=get_random_chops(slide_idxs,usable_slides,region_size)
+
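+    # Register an on-the-fly WSI patch dataset: each entry stores only the slide/XML paths and
+    # patch coordinates; pixels and masks are loaded lazily by CustomDatasetMapper during training.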
+ DatasetCatalog.register("my_dataset", lambda:train_samples_from_WSI(args,image_coordinates))
MetadataCatalog.get("my_dataset").set(thing_classes=tc)
MetadataCatalog.get("my_dataset").set(stuff_classes=sc)
+
- seg_metadata=MetadataCatalog.get("my_dataset")
-
-
- # new_list = DatasetCatalog.get("my_dataset")
- # print(len(new_list))
- # for d in random.sample(new_list, 100):
- #
- # img = cv2.imread(d["file_name"])
- # visualizer = Visualizer(img[:, :, ::-1],metadata=seg_metadata, scale=0.5)
- # out = visualizer.draw_dataset_dict(d)
- # cv2.namedWindow("output", cv2.WINDOW_NORMAL)
- # cv2.imshow("output",out.get_image()[:, :, ::-1])
- # cv2.waitKey(0) # waits until a key is pressed
- # cv2.destroyAllWindows()
- # exit()
+ _ = os.system("printf '\nTraining starts...\n'")
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml"))
cfg.DATASETS.TRAIN = ("my_dataset")
- cfg.DATASETS.TEST = ()
- num_cores = multiprocessing.cpu_count()
- cfg.DATALOADER.NUM_WORKERS = num_cores-3
- # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo
- cfg.MODEL.WEIGHTS = os.path.join('/hdd/bg/Detectron2/HAIL_Detectron2/liver/MODELS/0/HR', "model_final.pth")
+ #cfg.TEST.EVAL_PERIOD=args.eval_period
+ #cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.50
+ #num_cores = multiprocessing.cpu_count()
+ cfg.DATALOADER.NUM_WORKERS = args.num_workers
+
+ if args.init_modelfile:
+ cfg.MODEL.WEIGHTS = args.init_modelfile
+ else:
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml") # Let training initialize from model zoo
- cfg.SOLVER.IMS_PER_BATCH = 10
+ cfg.SOLVER.IMS_PER_BATCH = args.batch_size
- # cfg.SOLVER.BASE_LR = 0.02 # pick a good LR
- # cfg.SOLVER.LR_policy='steps_with_lrs'
- # cfg.SOLVER.MAX_ITER = 50000
- # cfg.SOLVER.STEPS = [30000,40000]
- # # cfg.SOLVER.STEPS = []
- # cfg.SOLVER.LRS = [0.002,0.0002]
- cfg.SOLVER.BASE_LR = 0.002 # pick a good LR
cfg.SOLVER.LR_policy='steps_with_lrs'
- cfg.SOLVER.MAX_ITER = 200000
- cfg.SOLVER.STEPS = [150000,180000]
- # cfg.SOLVER.STEPS = []
- cfg.SOLVER.LRS = [0.0002,0.00002]
-
- # cfg.INPUT.CROP.ENABLED = True
- # cfg.INPUT.CROP.TYPE='absolute'
- # cfg.INPUT.CROP.SIZE=[100,100]
- cfg.MODEL.BACKBONE.FREEZE_AT = 0
- # cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[4],[8],[16], [32], [64], [64], [64]]
- # cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p2', 'p2', 'p3','p4','p5','p6']
- cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.33, 0.5, 1.0, 2.0, 3.0]]
+ cfg.SOLVER.MAX_ITER = args.train_steps
+ cfg.SOLVER.BASE_LR = 0.00025 # pick a good LR
+ cfg.SOLVER.LRS = [0.000025,0.0000025]
+ cfg.SOLVER.STEPS = [70000,90000]
+ cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32],[64],[128], [256], [512], [1024]]
+ cfg.MODEL.RPN.IN_FEATURES = ['p2', 'p3', 'p4', 'p5','p6','p6']
+ cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[.1,.2,0.33, 0.5, 1.0, 2.0, 3.0,5,10]]
cfg.MODEL.ANCHOR_GENERATOR.ANGLES=[-90,-60,-30,0,30,60,90]
-
- cfg.MODEL.RPN.POSITIVE_FRACTION = 0.75
-
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(tc)
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES =len(sc)
-
-
- # cfg.INPUT.CROP.ENABLED = True
- # cfg.INPUT.CROP.TYPE='absolute'
- # cfg.INPUT.CROP.SIZE=[64,64]
-
- cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256 # faster, and good enough for this toy dataset (default: 512)
- # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4 # only has one class (ballon). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
-
+ cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64
cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS=False
- cfg.INPUT.MIN_SIZE_TRAIN=0
- # cfg.INPUT.MAX_SIZE_TRAIN=500
-
- # exit()
- os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
+ cfg.INPUT.MIN_SIZE_TRAIN=args.boxSize
+ cfg.INPUT.MAX_SIZE_TRAIN=args.boxSize
+ cfg.DATASETS.TEST = ()
+ cfg.OUTPUT_DIR = args.training_data_dir+"/output"
+ #os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
with open(cfg.OUTPUT_DIR+"/config_record.yaml", "w") as f:
f.write(cfg.dump()) # save config to file
- trainer = DefaultTrainer(cfg)
+ trainer = Trainer(cfg)
+
trainer.resume_or_load(resume=False)
trainer.train()
- cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") # path to the model we just trained
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.01 # set a custom testing threshold
- cfg.TEST.DETECTIONS_PER_IMAGE = 500
- #
- # cfg.INPUT.MIN_SIZE_TRAIN=64
- # cfg.INPUT.MAX_SIZE_TRAIN=4000
- cfg.INPUT.MIN_SIZE_TEST=64
- cfg.INPUT.MAX_SIZE_TEST=500
-
-
- predict_samples=100
- predictor = DefaultPredictor(cfg)
-
- dataset_dicts = samples_from_json_mini(json_file,predict_samples)
- iter=0
- if not os.path.exists(os.getcwd()+'/network_predictions/'):
- os.mkdir(os.getcwd()+'/network_predictions/')
- for d in random.sample(dataset_dicts, predict_samples):
- # print(d["file_name"])
- # imclass=d["file_name"].split('/')[-1].split('_')[-5].split(' ')[-1]
- # if imclass in ["TRI","HE"]:
- im = cv2.imread(d["file_name"])
- panoptic_seg, segments_info = predictor(im)["panoptic_seg"] # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
- # print(segments_info)
- # plt.imshow(panoptic_seg.to("cpu"))
- # plt.show()
- v = Visualizer(im[:, :, ::-1], seg_metadata, scale=1.2)
- v = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"), segments_info)
- # panoptic segmentation result
- # plt.ion()
- plt.subplot(121)
- plt.imshow(im[:, :, ::-1])
- plt.subplot(122)
- plt.imshow(v.get_image())
- plt.savefig(f"./network_predictions/input_{iter}.jpg",dpi=300)
- plt.show()
- # plt.ioff()
-
-
- # v = Visualizer(im[:, :, ::-1],
- # metadata=seg_metadata,
- # scale=0.5,
- # )
- # out = v.draw_panoptic_seg_predictions(panoptic_seg.to("cpu"),segments_info)
-
- # imsave('./network_predictions/pred'+str(iter)+'.png',np.hstack((im,v.get_image())))
- iter=iter+1
- # cv2.imshow('',out.get_image()[:, :, ::-1])
- # cv2.waitKey(0) # waits until a key is pressed
- # cv2.destroyAllWindows()
- #-----------------------------------------------------------------------------------------
-
- finish_model_generation(dirs,currentAnnotationIteration)
-
- print('\n\n\033[92;5mPlease place new wsi file(s) in: \n\t' + dirs['basedir'] + dirs['project']+ dirs['training_data_dir'] + str(currentAnnotationIteration+1))
- print('\nthen run [--option predict]\033[0m\n')
-
-
-
-
-def moveimages(startfolder,endfolder):
- filelist=glob.glob(startfolder + '*')
- for file in filelist:
- fileID=file.split('/')[-1]
- move(file,endfolder + fileID)
-
-
-def check_model_generation(dirs):
- modelsCurrent=os.listdir(dirs['basedir'] + dirs['project'] + dirs['modeldir'])
- gens=map(int,modelsCurrent)
- modelOrder=np.sort(gens)[::-1]
-
- for idx in modelOrder:
- #modelsChkptsLR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+str(modelsCurrent[idx]) + '/LR/*.ckpt*')
- modelsChkptsHR=glob.glob(dirs['basedir'] + dirs['project'] + dirs['modeldir']+ str(idx) +'/HR/*.ckpt*')
- if modelsChkptsHR == []:
- continue
- else:
- return idx
- break
-
-def finish_model_generation(dirs,currentAnnotationIteration):
- make_folder(dirs['basedir'] + dirs['project'] + dirs['training_data_dir'] + str(currentAnnotationIteration + 1))
-
-def get_pretrain(currentAnnotationIteration,res,dirs):
-
- if currentAnnotationIteration==0:
- pretrain_file = glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + '*')
- pretrain_file=pretrain_file[0].split('.')[0] + '.' + pretrain_file[0].split('.')[1]
-
- else:
- pretrains=glob.glob(dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model*')
- maxmodel=0
- for modelfiles in pretrains:
- modelID=modelfiles.split('.')[-2].split('-')[1]
- if int(modelID)>maxmodel:
- maxmodel=int(modelID)
- pretrain_file=dirs['basedir']+dirs['project'] + dirs['modeldir'] + str(currentAnnotationIteration) + res + 'model.ckpt-' + str(maxmodel)
- return pretrain_file
-
-def restart_line(): # for printing chopped image labels in command line
- sys.stdout.write('\r')
- sys.stdout.flush()
-
-def file_len(fname): # get txt file length (number of lines)
- with open(fname) as f:
- for i, l in enumerate(f):
- pass
- return i + 1
-
-def make_folder(directory):
- if not os.path.exists(directory):
- os.makedirs(directory) # make directory if it does not exit already # make new directory # Check if folder exists, if not make it
-
-def make_all_folders(dirs):
-
-
- make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/regions')
- make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/masks')
-
- make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/regions')
- make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirLR'] + '/Augment' +'/masks')
-
- make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/regions')
- make_folder(dirs['basedir'] +dirs['project']+ dirs['tempdirHR'] + '/masks')
-
- make_folder(dirs['basedir']+dirs['project'] + dirs['tempdirHR'] + '/Augment' +'/regions')
- make_folder(dirs['basedir']+dirs['project']+ dirs['tempdirHR'] + '/Augment' +'/masks')
-
- make_folder(dirs['basedir'] +dirs['project']+ dirs['modeldir'])
- make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir'])
-
-
- make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'regions/')
- make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/LR/'+ 'masks/')
- make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'regions/')
- make_folder(dirs['basedir'] +dirs['project']+ '/Permanent' +'/HR/'+ 'masks/')
-
- make_folder(dirs['basedir'] +dirs['project']+ dirs['training_data_dir'])
-
- make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetLR')
- make_folder(dirs['basedir'] + '/Codes/Deeplab_network/datasetHR')
-
-def return_region(args, wsi_mask, wsiID, fileID, yStart, xStart, idxy, idxx, downsampleRate, outdirT, region_size, dirs, chop_regions,classNum_HR): # perform cutting in parallel
- sys.stdout.write(' <'+str(xStart)+'/'+ str(yStart)+'/'+str(chop_regions[idxy,idxx] != 0)+ '> ')
- sys.stdout.flush()
- restart_line()
-
- if chop_regions[idxy,idxx] != 0:
-
- uniqID=fileID + str(yStart) + str(xStart)
- if wsiID.split('.')[-1] != 'tif':
- slide=getWsi(wsiID)
- Im=np.array(slide.read_region((xStart,yStart),0,(region_size,region_size)))
- Im=Im[:,:,:3]
- else:
- yEnd = yStart + region_size
- xEnd = xStart + region_size
- Im = np.zeros([region_size,region_size,3], dtype=np.uint8)
- Im_ = imread(wsiID)[yStart:yEnd, xStart:xEnd, :3]
- Im[0:Im_.shape[0], 0:Im_.shape[1], :] = Im_
-
- mask_annotation=wsi_mask[yStart:yStart+region_size,xStart:xStart+region_size]
-
- o1,o2=mask_annotation.shape
- if o1 !=region_size:
- mask_annotation=np.pad(mask_annotation,((0,region_size-o1),(0,0)),mode='constant')
- if o2 !=region_size:
- mask_annotation=np.pad(mask_annotation,((0,0),(0,region_size-o2)),mode='constant')
-
- '''
- if 4 in np.unique(mask_annotation):
- plt.subplot(121)
- plt.imshow(mask_annotation*20)
- plt.subplot(122)
- plt.imshow(Im)
- pt=[xStart,yStart]
- plt.title(pt)
- plt.show()
- '''
- if downsampleRate !=1:
- c=(Im.shape)
- s1=int(c[0]/(downsampleRate**.5))
- s2=int(c[1]/(downsampleRate**.5))
- Im=resize(Im,(s1,s2),mode='reflect')
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- imsave(outdirT + '/regions/' + uniqID + dirs['imExt'],Im)
- imsave(outdirT + '/masks/' + uniqID +dirs['maskExt'],mask_annotation)
-
-
-def regions_in_mask(root, bounds, verbose=1):
- # find regions to save
- IDs_reg = []
- IDs_points = []
-
- for Annotation in root.findall("./Annotation"): # for all annotations
- annotationID = Annotation.attrib['Id']
- annotationType = Annotation.attrib['Type']
-
- # print(Annotation.findall(./))
- if annotationType =='9':
- for element in Annotation.iter('InputAnnotationId'):
- pointAnnotationID=element.text
-
- for Region in Annotation.findall("./*/Region"): # iterate on all region
-
- for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region
- # get points
- x_point = np.int32(np.float64(Vertex.attrib['X']))
- y_point = np.int32(np.float64(Vertex.attrib['Y']))
- # test if points are in bounds
- if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds
- # save region Id
- IDs_points.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID,'pointAnnotationID':pointAnnotationID})
- break
- elif annotationType=='4':
-
- for Region in Annotation.findall("./*/Region"): # iterate on all region
-
- for Vertex in Region.findall("./*/Vertex"): # iterate on all vertex in region
- # get points
- x_point = np.int32(np.float64(Vertex.attrib['X']))
- y_point = np.int32(np.float64(Vertex.attrib['Y']))
- # test if points are in bounds
- if bounds['x_min'] <= x_point <= bounds['x_max'] and bounds['y_min'] <= y_point <= bounds['y_max']: # test points in region bounds
- # save region Id
- IDs_reg.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID})
- break
- return IDs_reg,IDs_points
-
-
-def get_vertex_points(root, IDs_reg,IDs_points, maskModes,excludedIDs,negativeIDs=None):
- Regions = []
- Points = []
-
- for ID in IDs_reg:
- Vertices = []
- if ID['annotationID'] not in excludedIDs:
- for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"):
- Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))])
- Regions.append({'Vertices':np.array(Vertices),'annotationID':ID['annotationID']})
-
- for ID in IDs_points:
- Vertices = []
- for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"):
- Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))])
- Points.append({'Vertices':np.array(Vertices),'pointAnnotationID':ID['pointAnnotationID']})
- if 'falsepositive' or 'negative' in maskModes:
- assert negativeIDs is not None,'Negatively annotated classes must be provided for negative/falsepositive mask mode'
- assert 'falsepositive' and 'negative' not in maskModes, 'Negative and false positive mask modes cannot both be true'
-
- useableRegions=[]
- if 'positive' in maskModes:
- for Region in Regions:
- regionPath=path.Path(Region['Vertices'])
- for Point in Points:
- if Region['annotationID'] not in negativeIDs:
- if regionPath.contains_point(Point['Vertices'][0]):
- Region['pointAnnotationID']=Point['pointAnnotationID']
- useableRegions.append(Region)
-
- if 'negative' in maskModes:
-
- for Region in Regions:
- regionPath=path.Path(Region['Vertices'])
- if Region['annotationID'] in negativeIDs:
- if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]):
- Region['pointAnnotationID']=Region['annotationID']
- useableRegions.append(Region)
- if 'falsepositive' in maskModes:
-
- for Region in Regions:
- regionPath=path.Path(Region['Vertices'])
- if Region['annotationID'] in negativeIDs:
- if not any([regionPath.contains_point(Point['Vertices'][0]) for Point in Points]):
- Region['pointAnnotationID']=0
- useableRegions.append(Region)
-
- return useableRegions
-def chop_suey_bounds(lb,xmlID,box_supervision_layers,wsiID,dirs,args):
- tree = ET.parse(xmlID)
- root = tree.getroot()
- lbVerts=np.array(lb['BoxVerts'])
- xMin=min(lbVerts[:,0])
- xMax=max(lbVerts[:,0])
- yMin=min(lbVerts[:,1])
- yMax=max(lbVerts[:,1])
-
- # test=np.array(slide.read_region((xMin,yMin),0,(xMax-xMin,yMax-yMin)))[:,:,:3]
-
- local_bound = {'x_min' : xMin, 'y_min' : yMin, 'x_max' : xMax, 'y_max' : yMax}
- IDs_reg,IDs_points = regions_in_mask_dots(root=root, bounds=local_bound,box_layers=box_supervision_layers)
-
- # find regions in bounds
- negativeIDs=['4']
- excludedIDs=['1']
- falsepositiveIDs=['4']
- usableRegions= get_vertex_points_dots(root=root, IDs_reg=IDs_reg,IDs_points=IDs_points,excludedIDs=excludedIDs,maskModes=['falsepositive','positive'],negativeIDs=negativeIDs,
- falsepositiveIDs=falsepositiveIDs)
-
- # image_sizes=
- masks_from_points(usableRegions,wsiID,dirs,50,args,[xMin,xMax,yMin,yMax])
-'''
-def masks_from_points(root,usableRegions,wsiID,dirs):
- pas_img = getWsi(wsiID)
- image_sizes=[]
- basename=wsiID.split('/')[-1].split('.svs')[0]
-
- for usableRegion in tqdm(usableRegions):
- vertices=usableRegion['Vertices']
- x1=min(vertices[:,0])
- x2=max(vertices[:,0])
- y1=min(vertices[:,1])
- y2=max(vertices[:,1])
- points = np.stack([np.asarray(vertices[:,0]), np.asarray(vertices[:,1])], axis=1)
- if (x2-x1)>0 and (y2-y1)>0:
- l1=x2-x1
- l2=y2-y1
- xMultiplier=np.ceil((l1)/64)
- yMultiplier=np.ceil((l2)/64)
- pad1=int(xMultiplier*64-l1)
- pad2=int(yMultiplier*64-l2)
-
- points[:,1] = np.int32(np.round(points[:,1] - y1 ))
- points[:,0] = np.int32(np.round(points[:,0] - x1 ))
- mask = 2*np.ones([y2-y1,x2-x1], dtype=np.uint8)
- if int(usableRegion['pointAnnotationID'])==0:
- pass
- else:
- cv2.fillPoly(mask, [points], int(usableRegion['pointAnnotationID'])-4)
- PAS = pas_img.read_region((x1,y1), 0, (x2-x1,y2-y1))
- # print(usableRegion['pointAnnotationID'])
- PAS = np.array(PAS)[:,:,0:3]
- mask=np.pad( mask,((0,pad2),(0,pad1)),'constant',constant_values=(2,2) )
- PAS=np.pad( PAS,((0,pad2),(0,pad1),(0,0)),'constant',constant_values=(0,0) )
-
- image_identifier=basename+'_'.join(['',str(x1),str(y1),str(l1),str(l2)])
- mask_out_name=dirs['basedir']+dirs['project'] + '/Permanent/HR/masks/'+image_identifier+'.png'
- image_out_name=mask_out_name.replace('/masks/','/regions/')
- # basename + '_' + str(image_identifier) + args.imBoxExt
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- imsave(image_out_name,PAS)
- imsave(mask_out_name,mask)
- # exit()
- # extract image region
- # plt.subplot(121)
- # plt.imshow(PAS)
- # plt.subplot(122)
- # plt.imshow(mask)
- # plt.show()
- # image_sizes.append([x2-x1,y2-y1])
+ _ = os.system("printf '\nTraining completed!\n'")
+
+def mask2polygons(mask):
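+    # Convert the label mask into detectron2 instance annotations: labels 3-6 are the "thing"
+    # classes, the -3 offset maps them to contiguous category ids 0-3, and each external
+    # contour becomes one polygon instance with an XYXY_ABS bounding box.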
+ annotation=[]
+ presentclasses=np.unique(mask)
+ offset=-3
+ presentclasses=presentclasses[presentclasses>2]
+ presentclasses=list(presentclasses[presentclasses<7])
+ for p in presentclasses:
+ contours, hierarchy = cv2.findContours(np.array(mask==p, dtype='uint8'), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+ for contour in contours:
+ if contour.size>=6:
+ instance_dict={}
+ contour_flat=contour.flatten().astype('float').tolist()
+ xMin=min(contour_flat[::2])
+ yMin=min(contour_flat[1::2])
+ xMax=max(contour_flat[::2])
+ yMax=max(contour_flat[1::2])
+ instance_dict['bbox']=[xMin,yMin,xMax,yMax]
+ instance_dict['bbox_mode']=BoxMode.XYXY_ABS.value
+ instance_dict['category_id']=p+offset
+ instance_dict['segmentation']=[contour_flat]
+ annotation.append(instance_dict)
+ return annotation
+
+
+class Trainer(DefaultTrainer):
+
+ @classmethod
+ def build_test_loader(cls, cfg, dataset_name):
+ return build_detection_test_loader(cfg, dataset_name, mapper=CustomDatasetMapper(cfg, True))
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ return build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True))
+
+
+class CustomDatasetMapper:
+
+ @configurable
+ def __init__(
+ self,
+ is_train: bool,
+ *,
+ augmentations: List[Union[T.Augmentation, T.Transform]],
+ image_format: str,
+ use_instance_mask: bool = False,
+ use_keypoint: bool = False,
+ instance_mask_format: str = "polygon",
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
+ precomputed_proposal_topk: Optional[int] = None,
+ recompute_boxes: bool = False,
+ ):
+ """
+ NOTE: this interface is experimental.
+
+ Args:
+ is_train: whether it's used in training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ use_instance_mask: whether to process instance segmentation annotations, if available
+ use_keypoint: whether to process keypoint annotations if available
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
+ masks into this format.
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
+ precomputed_proposal_topk: if given, will load pre-computed
+ proposals from dataset_dict and keep the top k proposals for each image.
+ recompute_boxes: whether to overwrite bounding box annotations
+ by computing tight bounding boxes from instance mask annotations.
+ """
+ if recompute_boxes:
+ assert use_instance_mask, "recompute_boxes requires instance masks"
+ # fmt: off
+ self.is_train = is_train
+ self.augmentations = T.AugmentationList(augmentations)
+ self.image_format = image_format
+ self.use_instance_mask = use_instance_mask
+ self.instance_mask_format = instance_mask_format
+ self.use_keypoint = use_keypoint
+ self.keypoint_hflip_indices = keypoint_hflip_indices
+ self.proposal_topk = precomputed_proposal_topk
+ self.recompute_boxes = recompute_boxes
+ # fmt: on
+ logger = logging.getLogger(__name__)
+ mode = "training" if is_train else "inference"
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
+
+ @classmethod
+ def from_config(cls, cfg, is_train: bool = True):
+ augs = utils.build_augmentation(cfg, is_train)
+ if cfg.INPUT.CROP.ENABLED and is_train:
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
+ recompute_boxes = cfg.MODEL.MASK_ON
else:
- print('Broken region')
- return image_sizes
-'''
+ recompute_boxes = False
+
+ ret = {
+ "is_train": is_train,
+ "augmentations": augs,
+ "image_format": cfg.INPUT.FORMAT,
+ "use_instance_mask": cfg.MODEL.MASK_ON,
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
+ "recompute_boxes": recompute_boxes,
+ }
+
+ if cfg.MODEL.KEYPOINT_ON:
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
+
+ if cfg.MODEL.LOAD_PROPOSALS:
+ ret["precomputed_proposal_topk"] = (
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
+ if is_train
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
+ )
+ return ret
+
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
+ # USER: Modify this if you want to keep them for some reason.
+ for anno in dataset_dict["annotations"]:
+ if not self.use_instance_mask:
+ anno.pop("segmentation", None)
+ if not self.use_keypoint:
+ anno.pop("keypoints", None)
+
+ annos = [
+ utils.transform_instance_annotations(
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+ )
+ for obj in dataset_dict.pop('annotations')
+ if obj.get("iscrowd", 0) == 0
+ ]
+ instances = utils.annotations_to_instances(
+ annos, image_shape, mask_format=self.instance_mask_format
+ )
+
+ if self.recompute_boxes:
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ c=dataset_dict['coordinates']
+ h=dataset_dict['height']
+ w=dataset_dict['width']
+
+ slide= TiffSlide(dataset_dict['slide_loc'])
+ image=np.array(slide.read_region((c[0],c[1]),0,(h,w)))[:,:,:3]
+ slide.close()
+ maskData=xml_to_mask(dataset_dict['xml_loc'], c, [h,w])
+
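+        # With 50% probability apply a color augmentation (random HSV hue shift plus a gamma
+        # adjustment of the LAB lightness channel) followed by the imgaug noise/blur pipeline `seq`.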
+ if random.random()>0.5:
+ hShift=np.random.normal(0,0.05)
+ lShift=np.random.normal(1,0.025)
+ # imageblock[im]=randomHSVshift(imageblock[im],hShift,lShift)
+ image=rgb2hsv(image)
+ image[:,:,0]=(image[:,:,0]+hShift)
+ image=hsv2rgb(image)
+ image=rgb2lab(image)
+ image[:,:,0]=exposure.adjust_gamma(image[:,:,0],lShift)
+ image=(lab2rgb(image)*255).astype('uint8')
+ image = seq(images=[image])[0].squeeze()
+
+ dataset_dict['annotations']=mask2polygons(maskData)
+ utils.check_image_size(dataset_dict, image)
+
+        sem_seg_gt = maskData.copy() # copy so zeroing the thing pixels below does not feed back into the maskData==0 test
+ sem_seg_gt[sem_seg_gt>2]=0
+ sem_seg_gt[maskData==0] = 3
+ sem_seg_gt=np.array(sem_seg_gt).astype('uint8')
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+ transforms = self.augmentations(aug_input)
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
+
+ image_shape = image.shape[:2] # h, w
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+ if sem_seg_gt is not None:
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
+
+ if "annotations" in dataset_dict:
+ self._transform_annotations(dataset_dict, transforms, image_shape)
+
+
+ return dataset_dict
diff --git a/multic/segmentationschool/Codes/wsi_loader_utils.py b/multic/segmentationschool/Codes/wsi_loader_utils.py
index ea87ae6..0e1edb4 100644
--- a/multic/segmentationschool/Codes/wsi_loader_utils.py
+++ b/multic/segmentationschool/Codes/wsi_loader_utils.py
@@ -1,85 +1,85 @@
-import openslide,glob,os
+import glob, os
import numpy as np
from scipy.ndimage.morphology import binary_fill_holes
-import matplotlib.pyplot as plt
from skimage.color import rgb2hsv
from skimage.filters import gaussian
-# from skimage.morphology import binary_dilation, diamond
-# import cv2
from tqdm import tqdm
from skimage.io import imread,imsave
import multiprocessing
from joblib import Parallel, delayed
-
-def save_thumb(args,slide_loc):
- print(slide_loc)
- slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1])
- slide=openslide.OpenSlide(slide_loc)
- if slideExt =='.scn':
- dim_x=int(slide.properties['openslide.bounds-width'])## add to columns
- dim_y=int(slide.properties['openslide.bounds-height'])## add to rows
- offsetx=int(slide.properties['openslide.bounds-x'])##start column
- offsety=int(slide.properties['openslide.bounds-y'])##start row
- elif slideExt in ['.ndpi','.svs']:
- dim_x, dim_y=slide.dimensions
- offsetx=0
- offsety=0
-
- # fullSize=slide.level_dimensions[0]
- # resRatio= args.chop_thumbnail_resolution
- # ds_1=fullSize[0]/resRatio
- # ds_2=fullSize[1]/resRatio
- # thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2)))
- # if slideExt =='.scn':
- # xStt=int(offsetx/resRatio)
- # xStp=int((offsetx+dim_x)/resRatio)
- # yStt=int(offsety/resRatio)
- # yStp=int((offsety+dim_y)/resRatio)
- # thumbIm=thumbIm[yStt:yStp,xStt:xStp]
- # imsave(slide_loc.replace(slideExt,'_thumb.jpeg'),thumbIm)
- slide.associated_images['label'].save(slide_loc.replace(slideExt,'_label.png'))
- # imsave(slide_loc.replace(slideExt,'_label.png'),slide.associated_images['label'])
-
-
-def get_image_thumbnails(args):
- assert args.target is not None, 'Location of images must be provided'
- all_slides=[]
- for ext in args.wsi_ext.split(','):
- all_slides.extend(glob.glob(args.target+'/*'+ext))
- Parallel(n_jobs=multiprocessing.cpu_count())(delayed(save_thumb)(args,slide_loc) for slide_loc in tqdm(all_slides))
- # for slide_loc in tqdm(all_slides):
-
-class WSIPredictLoader():
- def __init__(self,args, wsi_directory=None, transform=None):
+from shapely.geometry import Polygon
+from tiffslide import TiffSlide
+import random
+import glob
+import warnings
+from joblib import Parallel, delayed
+import multiprocessing
+from .xml_to_mask_minmax import write_minmax_to_xml
+import lxml.etree as ET
+
+def get_image_meta(i,args):
+ image_annotation_info={}
+ image_annotation_info['slide_loc']=i[0]
+ slide=TiffSlide(image_annotation_info['slide_loc'])
+ magx=np.round(float(slide.properties['tiffslide.mpp-x']),2)
+ magy=np.round(float(slide.properties['tiffslide.mpp-y']),2)
+
+ assert magx == magy
+ if magx ==0.25:
+ dx=args.boxSize
+ dy=args.boxSize
+ elif magx == 0.5:
+ dx=int(args.boxSize/2)
+ dy=int(args.boxSize/2)
+ else:
+ print('nonstandard image magnification')
+ print(slide)
+ print(magx,magy)
+ exit()
+
+ image_annotation_info['coordinates']=[i[2][1],i[2][0]]
+ image_annotation_info['height']=dx
+ image_annotation_info['width']=dy
+ image_annotation_info['image_id']=i[1].split('/')[-1].replace('.xml','_'.join(['',str(i[2][1]),str(i[2][0])]))
+ image_annotation_info['xml_loc']=i[1]
+ image_annotation_info['file_name']=i[1].split('/')[-1]
+ slide.close()
+ return image_annotation_info
+
+def train_samples_from_WSI(args,image_coordinates):
+
+
+ num_cores=multiprocessing.cpu_count()
+ print('Generating detectron2 dictionary format...')
+ data_list=Parallel(n_jobs=num_cores,backend='threading')(delayed(get_image_meta)(i=i,
+ args=args) for i in tqdm(image_coordinates))
+ return data_list
+
+def WSIGridIterator(wsi_name,choppable_regions,index_x,index_y,region_size,dim_x,dim_y):
+ wsi_name=os.path.splitext(wsi_name.split('/')[-1])[0]
+ data_list=[]
+ for idxy, i in tqdm(enumerate(index_y)):
+ for idxx, j in enumerate(index_x):
+ if choppable_regions[idxy, idxx] != 0:
+ yEnd = min(dim_y,i+region_size)
+ xEnd = min(dim_x,j+region_size)
+ xLen=xEnd-j
+ yLen=yEnd-i
+
+ image_annotation_info={}
+ image_annotation_info['file_name']='_'.join([wsi_name,str(j),str(i),str(xEnd),str(yEnd)])
+ image_annotation_info['height']=yLen
+ image_annotation_info['width']=xLen
+ image_annotation_info['image_id']=image_annotation_info['file_name']
+ image_annotation_info['xStart']=j
+ image_annotation_info['yStart']=i
+ data_list.append(image_annotation_info)
+ return data_list
+
+def get_slide_data(args, wsi_directory=None):
assert wsi_directory is not None, 'location of training svs and xml must be provided'
- mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),)
- if not os.path.exists(mask_out_loc):
- os.makedirs(mask_out_loc)
- all_slides=[]
- for ext in args.wsi_ext.split(','):
- all_slides.extend(glob.glob(wsi_directory+'/*'+ext))
- print('Getting slide metadata and usable regions...')
- usable_slides=[]
- for slide_loc in all_slides:
- slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1])
- print("working slide... "+ slideID,end='\r')
-
- slide=openslide.OpenSlide(slide_loc)
- chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc)
- mag_x=np.round(float(slide.properties['openslide.mpp-x']),2)
- mag_y=np.round(float(slide.properties['openslide.mpp-y']),2)
- print(mag_x,mag_y)
- usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide,
- 'chop_array':chop_array,'mag':[mag_x,mag_y]})
- self.usable_slides= usable_slides
- self.boxSize40X = args.boxSize
- self.boxSize20X = int(args.boxSize)/2
-
-class WSITrainingLoader():
- def __init__(self,args, wsi_directory=None, transform=None):
- assert wsi_directory is not None, 'location of training svs and xml must be provided'
- mask_out_loc=os.path.join(wsi_directory.replace('/TRAINING_data/0','Permanent/Tissue_masks/'),)
+ mask_out_loc=os.path.join(wsi_directory, 'Tissue_masks')
if not os.path.exists(mask_out_loc):
os.makedirs(mask_out_loc)
all_slides=[]
@@ -90,34 +90,102 @@ def __init__(self,args, wsi_directory=None, transform=None):
usable_slides=[]
for slide_loc in all_slides:
slideID,slideExt=os.path.splitext(slide_loc.split('/')[-1])
- print("working slide... "+ slideID,end='\r')
-
- slide=openslide.OpenSlide(slide_loc)
- chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc)
- mag_x=np.round(float(slide.properties['openslide.mpp-x']),2)
- mag_y=np.round(float(slide.properties['openslide.mpp-y']),2)
- print(mag_x,mag_y)
- usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,'slideExt':slideExt,'slide':slide,
- 'chop_array':chop_array,'mag':[mag_x,mag_y]})
- self.usable_slides= usable_slides
- self.boxSize40X = args.boxSize
- self.boxSize20X = int(args.boxSize)/2
+ xmlpath=slide_loc.replace(slideExt,'.xml')
+ if os.path.isfile(xmlpath):
+ write_minmax_to_xml(xmlpath)
+
+ print("Gathering slide data ... "+ slideID,end='\r')
+ slide =TiffSlide(slide_loc)
+ chop_array=get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc)
+
+ mag_x=np.round(float(slide.properties['tiffslide.mpp-x']),2)
+ mag_y=np.round(float(slide.properties['tiffslide.mpp-y']),2)
+ slide.close()
+ tree = ET.parse(xmlpath)
+ root = tree.getroot()
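+            # Count the annotated regions for each class listed in args.balanceClasses;
+            # get_chop_data uses these counts to bias patch sampling toward annotated structures.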
+ balance_classes=args.balanceClasses.split(',')
+ classNums={}
+ for b in balance_classes:
+ classNums[b]=0
+ # balance_annotations={}
+ for Annotation in root.findall("./Annotation"):
+
+ annotationID = Annotation.attrib['Id']
+ if annotationID=='7':
+ print(xmlpath)
+ exit()
+ if annotationID in classNums.keys():
+
+ classNums[annotationID]=len(Annotation.findall("./*/Region"))
+ else:
+ pass
+
+ usable_slides.append({'slide_loc':slide_loc,'slideID':slideID,
+ 'chop_array':chop_array,'num_regions':len(chop_array),'mag':[mag_x,mag_y],
+ 'xml_loc':xmlpath,'annotations':classNums,'root':root
+ })
+ else:
+ print('\n')
+ print('no annotation XML file found for:')
+ print(slideID)
+ exit()
print('\n')
+ return usable_slides
+
+def get_random_chops(slide_idx,usable_slides,region_size):
+ # chops=[]
+ choplen=len(slide_idx)
+ chops=Parallel(n_jobs=multiprocessing.cpu_count(),backend='threading')(delayed(get_chop_data)(idx=idx,
+ usable_slides=usable_slides,region_size=region_size) for idx in tqdm(slide_idx))
+ return chops
+
+
+def get_chop_data(idx,usable_slides,region_size):
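+    # Half the time sample a random tissue patch from the slide's precomputed chop grid; otherwise
+    # center the patch on the centroid of a randomly sampled annotated region (skipping empty
+    # classes and class '5'), falling back to the grid when the slide has no usable annotations.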
+ if random.random()>0.5:
+ randSelect=random.randrange(0,usable_slides[idx]['num_regions'])
+ chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'],
+ usable_slides[idx]['chop_array'][randSelect]]
+ else:
+ # print(list(usable_slides[idx]['annotations'].values()))
+ if sum(usable_slides[idx]['annotations'].values())==0:
+ randSelect=random.randrange(0,usable_slides[idx]['num_regions'])
+ chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'],
+ usable_slides[idx]['chop_array'][randSelect]]
+ else:
+ classIDs=list(usable_slides[idx]['annotations'].keys())
+ classSamples=random.sample(classIDs,len(classIDs))
+ for c in classSamples:
+ if usable_slides[idx]['annotations'][c]==0 or c == '5':
+ continue
+ else:
+ sampledRegionID=random.randrange(1,usable_slides[idx]['annotations'][c]+1)
+
+ break
+
+
+ Verts = usable_slides[idx]['root'].findall("./Annotation[@Id='{}']/Regions/Region[@Id='{}']/Vertices/Vertex".format(c,sampledRegionID))
+ centroid = (Polygon([(int(float(k.attrib['X'])),int(float(k.attrib['Y']))) for k in Verts]).centroid)
+
+ randVertX=int(centroid.x)-region_size//2
+ randVertY=int(centroid.y)-region_size//2
+
+ chopData=[usable_slides[idx]['slide_loc'],usable_slides[idx]['xml_loc'],
+ [randVertY,randVertX]]
+
+ return chopData
def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc):
slide_regions=[]
choppable_regions_list=[]
-
downsample = int(args.downsampleRate**.5) #down sample for each dimension
region_size = int(args.boxSize*(downsample)) #Region size before downsampling
- step = int(region_size*(1-args.overlap_percent)) #Step size before downsampling
-
+ step = int(region_size*(1-args.overlap_rate)) #Step size before downsampling
if slideExt =='.scn':
- dim_x=int(slide.properties['openslide.bounds-width'])## add to columns
- dim_y=int(slide.properties['openslide.bounds-height'])## add to rows
- offsetx=int(slide.properties['openslide.bounds-x'])##start column
- offsety=int(slide.properties['openslide.bounds-y'])##start row
+ dim_x=int(slide.properties['tiffslide.bounds-width'])## add to columns
+ dim_y=int(slide.properties['tiffslide.bounds-height'])## add to rows
+ offsetx=int(slide.properties['tiffslide.bounds-x'])##start column
+ offsety=int(slide.properties['tiffslide.bounds-y'])##start row
index_y=np.array(range(offsety,offsety+dim_y,step))
index_x=np.array(range(offsetx,offsetx+dim_x,step))
index_y[-1]=(offsety+dim_y)-step
@@ -135,7 +203,9 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc):
resRatio= args.chop_thumbnail_resolution
ds_1=fullSize[0]/resRatio
ds_2=fullSize[1]/resRatio
- if args.get_new_tissue_masks:
+ out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png']))
+ if not os.path.isfile(out_mask_name) or args.get_new_tissue_masks:
+ print(out_mask_name)
thumbIm=np.array(slide.get_thumbnail((ds_1,ds_2)))
if slideExt =='.scn':
xStt=int(offsetx/resRatio)
@@ -143,40 +213,18 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc):
yStt=int(offsety/resRatio)
yStp=int((offsety+dim_y)/resRatio)
thumbIm=thumbIm[yStt:yStp,xStt:xStp]
- # plt.imshow(thumbIm)
- # plt.show()
- # input()
- # plt.imshow(thumbIm)
- # plt.show()
-
- out_mask_name=os.path.join(mask_out_loc,'_'.join([slideID,slideExt[1:]+'.png']))
-
-
- if not args.get_new_tissue_masks:
- try:
- binary=(imread(out_mask_name)/255).astype('bool')
- except:
- print('failed to load mask for '+ out_mask_name)
- print('please set get_new_tissue masks to True')
- exit()
- # if slideExt =='.scn':
- # choppable_regions=np.zeros((len(index_x),len(index_y)))
- # elif slideExt in ['.ndpi','.svs']:
- choppable_regions=np.zeros((len(index_y),len(index_x)))
- else:
- print(out_mask_name)
- # if slideExt =='.scn':
- # choppable_regions=np.zeros((len(index_x),len(index_y)))
- # elif slideExt in ['.ndpi','.svs']:
choppable_regions=np.zeros((len(index_y),len(index_x)))
-
hsv=rgb2hsv(thumbIm)
g=gaussian(hsv[:,:,1],5)
binary=(g>0.05).astype('bool')
binary=binary_fill_holes(binary)
imsave(out_mask_name.replace('.png','.jpeg'),thumbIm)
- imsave(out_mask_name,binary.astype('uint8')*255)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ imsave(out_mask_name,binary.astype('uint8')*255)
+        # load the tissue mask back from disk (covers both the cached and the freshly written case)
+        binary=(imread(out_mask_name)/255).astype('bool')
+ choppable_regions=np.zeros((len(index_y),len(index_x)))
chop_list=[]
for idxy,yi in enumerate(index_y):
for idxx,xj in enumerate(index_x):
@@ -185,31 +233,13 @@ def get_choppable_regions(slide,args,slideID,slideExt,mask_out_loc):
xStart = int(np.round((xj-offsetx)/resRatio))
xStop = int(np.round(((xj-offsetx)+args.boxSize)/resRatio))
box_total=(xStop-xStart)*(yStop-yStart)
- if slideExt =='.scn':
- # print(xStart,xStop,yStart,yStop)
- # print(np.sum(binary[xStart:xStop,yStart:yStop]),args.white_percent,box_total)
- # plt.imshow(binary[xStart:xStop,yStart:yStop])
- # plt.show()
- if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total):
-
- choppable_regions[idxy,idxx]=1
- chop_list.append([index_y[idxy],index_x[idxx]])
-
- elif slideExt in ['.ndpi','.svs']:
- if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total):
- choppable_regions[idxy,idxx]=1
- chop_list.append([index_y[idxy],index_x[idxx]])
-
- imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255)
-
- # plt.imshow(choppable_regions)
- # plt.show()
- # choppable_regions_list.extend(chop_list)
- # plt.subplot(131)
- # plt.imshow(thumbIm)
- # plt.subplot(132)
- # plt.imshow(binary)
- # plt.subplot(133)
- # plt.imshow(choppable_regions)
- # plt.show()
+
+ if np.sum(binary[yStart:yStop,xStart:xStop])>(args.white_percent*box_total):
+ choppable_regions[idxy,idxx]=1
+ chop_list.append([index_y[idxy],index_x[idxx]])
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ imsave(out_mask_name.replace('.png','_chopregions.png'),choppable_regions.astype('uint8')*255)
+
return chop_list
diff --git a/multic/segmentationschool/segmentation_school.py b/multic/segmentationschool/segmentation_school.py
index 50e32c1..a0d5b42 100644
--- a/multic/segmentationschool/segmentation_school.py
+++ b/multic/segmentationschool/segmentation_school.py
@@ -174,7 +174,7 @@ def savetime(args, starttime):
##### Args for training / prediction ####################################################
parser.add_argument('--gpu_num', dest='gpu_num', default=2 ,type=int,
help='number of GPUs avalable')
- parser.add_argument('--gpu', dest='gpu', default=0 ,type=int,
+ parser.add_argument('--gpu', dest='gpu', default="" ,type=str,
help='GPU to use for prediction')
parser.add_argument('--iteration', dest='iteration', default='none' ,type=str,
help='Which iteration to use for prediction')
@@ -188,6 +188,18 @@ def savetime(args, starttime):
help='number of classes present in the High res training data [USE ONLY IF DIFFERENT FROM LOW RES]')
parser.add_argument('--modelfile', dest='modelfile', default=None ,type=str,
help='the desired model file to use for training or prediction')
+    parser.add_argument('--init_modelfile', dest='init_modelfile', default=None ,type=str,
+            help='the initial model file to start training from')
+ parser.add_argument('--eval_period', dest='eval_period', default=1000 ,type=int,
+ help='Validation Period')
+ parser.add_argument('--batch_size', dest='batch_size', default=4 ,type=int,
+ help='Size of batches for training high resolution CNN')
+    parser.add_argument('--train_steps', dest='train_steps', default=1000 ,type=int,
+            help='Number of training steps for the high resolution CNN')
+ parser.add_argument('--training_data_dir', dest='training_data_dir', default=os.getcwd(),type=str,
+ help='Training Data Folder')
+ parser.add_argument('--overlap_rate', dest='overlap_rate', default=0.5 ,type=float,
+ help='overlap percentage of high resolution blocks [0-1]')
### Params for cutting wsi ###
#White level cutoff
@@ -202,10 +214,14 @@ def savetime(args, starttime):
help='size of low resolution blocks')
parser.add_argument('--downsampleRateLR', dest='downsampleRateLR', default=16 ,type=int,
help='reduce image resolution to 1/downsample rate')
+    parser.add_argument('--get_new_tissue_masks', dest='get_new_tissue_masks', default=False,type=str2bool,
+                        help="Don't load usable tissue regions from disk, create new ones")
+ parser.add_argument('--downsampleRate', dest='downsampleRate', default=1 ,type=int,
+ help='reduce image resolution to 1/downsample rate')
#High resolution parameters
parser.add_argument('--overlap_percentHR', dest='overlap_percentHR', default=0 ,type=float,
help='overlap percentage of high resolution blocks [0-1]')
- parser.add_argument('--boxSize', dest='boxSize', default=2048 ,type=int,
+ parser.add_argument('--boxSize', dest='boxSize', default=1200 ,type=int,
help='size of high resolution blocks')
parser.add_argument('--downsampleRateHR', dest='downsampleRateHR', default=1 ,type=int,
help='reduce image resolution to 1/downsample rate')
@@ -226,7 +242,8 @@ def savetime(args, starttime):
help='Gaussian variance defining bounds on Hue shift for HSV color augmentation')
parser.add_argument('--lbound', dest='lbound', default=0.025 ,type=float,
help='Gaussian variance defining bounds on L* gamma shift for color augmentation [alters brightness/darkness of image]')
-
+ parser.add_argument('--balanceClasses', dest='balanceClasses', default='3,4,5,6',type=str,
+ help="which classes to balance during training")
### Params for training networks ###
#Low resolution hyperparameters
parser.add_argument('--CNNbatch_sizeLR', dest='CNNbatch_sizeLR', default=2 ,type=int,
@@ -280,6 +297,8 @@ def savetime(args, starttime):
help='padded region for low resolution region extraction')
parser.add_argument('--show_interstitium', dest='show_interstitium', default=True ,type=str2bool,
help='padded region for low resolution region extraction')
+ parser.add_argument('--num_workers', dest='num_workers', default=1 ,type=int,
+ help='Number of workers for data loader')
diff --git a/multic/segmentationschool/utils/mask_to_xml.py b/multic/segmentationschool/utils/mask_to_xml.py
new file mode 100644
index 0000000..34217e5
--- /dev/null
+++ b/multic/segmentationschool/utils/mask_to_xml.py
@@ -0,0 +1,135 @@
+import cv2
+import numpy as np
+import lxml.etree as ET
+
+"""
+xml_path (string) - the filename of the saved xml
+mask (array) - the mask to convert to xml - uint8 array
+downsample (int) - amount of downsampling done to the mask
+ points are upsampled - this can be used to simplify the mask
+min_size_thresh (int) - the minimum object size allowed in the mask. This is referenced from downsample=1
+xml_color (list) - list of integer color values to be used for the annotation classes
+
+"""
+
+def mask_to_xml(xml_path, mask, downsample=1, min_size_thresh=0, simplify_contours=0, xml_color=[65280, 65535, 33023, 255, 16711680], verbose=0, return_root=False, maxClass=None, offset={'X': 0,'Y': 0}):
+
+ min_size_thresh /= downsample
+
+ # create xml tree
+ Annotations = xml_create()
+
+ # get all classes
+ classes = np.unique(mask)
+ if maxClass is None:
+ maxClass = max(classes)
+
+ # add annotation classes to tree
+ for class_ in range(maxClass+1)[1:]:
+ if verbose:
+ print('Creating class: [{}]'.format(class_))
+ Annotations = xml_add_annotation(Annotations=Annotations, xml_color=xml_color, annotationID=class_)
+
+ # add contour points to tree classwise
+ for class_ in classes: # iterate through all classes
+
+ if class_ == 0 or class_ > maxClass:
+ continue
+
+ if verbose:
+ print('Working on class [{} of {}]'.format(class_, max(classes)))
+
+ # binarize the mask w.r.t. class_
+ binaryMask = mask==class_
+
+ # get contour points of the mask
+ pointsList = get_contour_points(binaryMask, downsample=downsample, min_size_thresh=min_size_thresh, simplify_contours=simplify_contours, offset=offset)
+ for i in range(np.shape(pointsList)[0]):
+ pointList = pointsList[i]
+ Annotations = xml_add_region(Annotations=Annotations, pointList=pointList, annotationID=class_)
+
+ if return_root:
+ # return root, do not save xml file
+ return Annotations
+
+ # save the final xml file
+ xml_save(Annotations=Annotations, filename='{}.xml'.format(xml_path.split('.')[0]))
+
+
+def get_contour_points(mask, downsample, min_size_thresh=0, simplify_contours=0, offset={'X': 0,'Y': 0}):
+    # returns a list of pointLists; each point is a dict with 'X' and 'Y' values
+ # input greyscale binary image
+ #_, maskPoints, contours = cv2.findContours(np.array(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
+ maskPoints, contours = cv2.findContours(np.uint8(mask), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1)
+ maskPoints = list(maskPoints)
+ # remove small regions
+ too_small = []
+ for idx, cnt in enumerate(maskPoints):
+ area = cv2.contourArea(cnt)
+ if area < min_size_thresh:
+ too_small.append(idx)
+ if too_small != []:
+ too_small.reverse()
+ for idx in too_small:
+ maskPoints.pop(idx)
+
+ if simplify_contours > 0:
+ for idx, cnt in enumerate(maskPoints):
+ epsilon = simplify_contours*cv2.arcLength(cnt,True)
+ approx = cv2.approxPolyDP(cnt,epsilon,True)
+ maskPoints[idx] = approx
+
+ pointsList = []
+ for j in range(np.shape(maskPoints)[0]):
+ pointList = []
+ for i in range(0,np.shape(maskPoints[j])[0]):
+ point = {'X': (maskPoints[j][i][0][0] * downsample) + offset['X'], 'Y': (maskPoints[j][i][0][1] * downsample) + offset['Y']}
+ pointList.append(point)
+ pointsList.append(pointList)
+ return pointsList
+
+### functions for building an xml tree of annotations ###
+def xml_create(): # create new xml tree
+ # create new xml Tree - Annotations
+ Annotations = ET.Element('Annotations')
+ return Annotations
+
+def xml_add_annotation(Annotations, xml_color, annotationID=None): # add new annotation
+ # add new Annotation to Annotations
+    # defaults to a new annotationID
+ if annotationID == None: # not specified
+ annotationID = len(Annotations.findall('Annotation')) + 1
+ Annotation = ET.SubElement(Annotations, 'Annotation', attrib={'Type': '4', 'Visible': '1', 'ReadOnly': '0', 'Incremental': '0', 'LineColorReadOnly': '0', 'LineColor': str(xml_color[annotationID-1]), 'Id': str(annotationID), 'NameReadOnly': '0'})
+ Regions = ET.SubElement(Annotation, 'Regions')
+ return Annotations
+
+def xml_add_region(Annotations, pointList, annotationID=-1, regionID=None): # add new region to annotation
+ # add new Region to Annotation
+    # defaults to the last annotationID and a new regionID
+ Annotation = Annotations.find("Annotation[@Id='" + str(annotationID) + "']")
+ Regions = Annotation.find('Regions')
+ if regionID == None: # not specified
+ regionID = len(Regions.findall('Region')) + 1
+ Region = ET.SubElement(Regions, 'Region', attrib={'NegativeROA': '0', 'ImageFocus': '-1', 'DisplayId': '1', 'InputRegionId': '0', 'Analyze': '0', 'Type': '0', 'Id': str(regionID)})
+ Vertices = ET.SubElement(Region, 'Vertices')
+ for point in pointList: # add new Vertex
+ ET.SubElement(Vertices, 'Vertex', attrib={'X': str(point['X']), 'Y': str(point['Y']), 'Z': '0'})
+ # add connecting point
+ ET.SubElement(Vertices, 'Vertex', attrib={'X': str(pointList[0]['X']), 'Y': str(pointList[0]['Y']), 'Z': '0'})
+ return Annotations
+
+def xml_save(Annotations, filename):
+ xml_data = ET.tostring(Annotations, pretty_print=True)
+ #xml_data = Annotations.toprettyxml()
+ f = open(filename, 'w')
+ f.write(xml_data.decode())
+ f.close()
+
+def read_xml(filename):
+ # import xml file
+ tree = ET.parse(filename)
+ root = tree.getroot()
+
+if __name__ == '__main__':
+    pass  # no standalone entry point is defined; call mask_to_xml() directly
+
\ No newline at end of file
diff --git a/multic/segmentationschool/utils/xml_to_mask.py b/multic/segmentationschool/utils/xml_to_mask.py
new file mode 100644
index 0000000..aeecd6a
--- /dev/null
+++ b/multic/segmentationschool/utils/xml_to_mask.py
@@ -0,0 +1,219 @@
+import numpy as np
+import lxml.etree as ET
+import cv2
+import time
+import os
+
+"""
+location (tuple) - (x, y) tuple giving the top left pixel in the level 0 reference frame
+size (tuple) - (width, height) tuple giving the region size | set to 'full' for entire mask
+downsample - int giving the amount of downsampling done to the output pixel mask
+
+NOTE: if you plan to loop through xmls in parallel, it is necessary to run write_minmax_to_xml()
+      on all the files beforehand - to avoid conflicting file writes
+
+"""
+
+def xml_to_mask(xml_path, location, size, tree=None, downsample=1, verbose=0):
+
+ # parse xml and get root
+ if tree == None: tree = ET.parse(xml_path)
+ root = tree.getroot()
+
+ if size == 'full':
+ import math
+ size = write_minmax_to_xml(xml_path=xml_path, tree=tree, get_absolute_max=True)
+ size = (math.ceil(size[0]/downsample), math.ceil(size[1]/downsample))
+ location = (0,0)
+
+ # calculate region bounds
+ bounds = {'x_min' : location[0], 'y_min' : location[1], 'x_max' : location[0] + size[0]*downsample, 'y_max' : location[1] + size[1]*downsample}
+
+ IDs = regions_in_mask(xml_path=xml_path, root=root, tree=tree, bounds=bounds, verbose=verbose)
+
+ if verbose != 0:
+ print('\nFOUND: ' + str(len(IDs)) + ' regions')
+
+ # find regions in bounds
+ Regions = get_vertex_points(root=root, IDs=IDs, verbose=verbose)
+
+ # fill regions and create mask
+ mask = Regions_to_mask(Regions=Regions, bounds=bounds, IDs=IDs, downsample=downsample, verbose=verbose)
+ if verbose != 0:
+ print('done...\n')
+
+ return mask
+
+
+def regions_in_mask(xml_path, root, tree, bounds, verbose=1):
+ # find regions to save
+ IDs = []
+ mtime = os.path.getmtime(xml_path)
+
+ write_minmax_to_xml(xml_path, tree)
+
+ for Annotation in root.findall("./Annotation"): # for all annotations
+ annotationID = Annotation.attrib['Id']
+
+ for Region in Annotation.findall("./*/Region"): # iterate on all region
+
+ for Vert in Region.findall("./Vertices"): # iterate on all vertex in region
+
+ # get minmax points
+ Xmin = np.int32(Vert.attrib['Xmin'])
+ Ymin = np.int32(Vert.attrib['Ymin'])
+ Xmax = np.int32(Vert.attrib['Xmax'])
+ Ymax = np.int32(Vert.attrib['Ymax'])
+
+ # test minmax points in region bounds
+ if bounds['x_min'] <= Xmax and bounds['x_max'] >= Xmin and bounds['y_min'] <= Ymax and bounds['y_max'] >= Ymin:
+ # save region Id
+ IDs.append({'regionID' : Region.attrib['Id'], 'annotationID' : annotationID})
+ break
+ return IDs
+
+def get_vertex_points(root, IDs, verbose=1):
+ Regions = []
+
+ for ID in IDs: # for all IDs
+
+ # get all vertex attributes (points)
+ Vertices = []
+
+ for Vertex in root.findall("./Annotation[@Id='" + ID['annotationID'] + "']/Regions/Region[@Id='" + ID['regionID'] + "']/Vertices/Vertex"):
+ # make array of points
+ Vertices.append([int(float(Vertex.attrib['X'])), int(float(Vertex.attrib['Y']))])
+
+ Regions.append(np.array(Vertices))
+
+ return Regions
+
+def Regions_to_mask(Regions, bounds, IDs, downsample, verbose=1):
+ # downsample = int(np.round(downsample_factor**(.5)))
+
+ if verbose !=0:
+ print('\nMAKING MASK:')
+
+ if len(Regions) != 0: # regions present
+ # get min/max sizes
+ min_sizes = np.empty(shape=[2,0], dtype=np.int32)
+ max_sizes = np.empty(shape=[2,0], dtype=np.int32)
+ for Region in Regions: # fill all regions
+ min_bounds = np.reshape((np.amin(Region, axis=0)), (2,1))
+ max_bounds = np.reshape((np.amax(Region, axis=0)), (2,1))
+ min_sizes = np.append(min_sizes, min_bounds, axis=1)
+ max_sizes = np.append(max_sizes, max_bounds, axis=1)
+ min_size = np.amin(min_sizes, axis=1)
+ max_size = np.amax(max_sizes, axis=1)
+
+ # add to old bounds
+ bounds['x_min_pad'] = min(min_size[1], bounds['x_min'])
+ bounds['y_min_pad'] = min(min_size[0], bounds['y_min'])
+ bounds['x_max_pad'] = max(max_size[1], bounds['x_max'])
+ bounds['y_max_pad'] = max(max_size[0], bounds['y_max'])
+
+ # make blank mask
+ mask = np.zeros([ int(np.round((bounds['y_max_pad'] - bounds['y_min_pad']) / downsample)), int(np.round((bounds['x_max_pad'] - bounds['x_min_pad']) / downsample)) ], dtype=np.uint8)
+
+ # fill mask polygons
+ index = 0
+ for Region in Regions:
+ # reformat Regions
+ Region[:,1] = np.int32(np.round((Region[:,1] - bounds['y_min_pad']) / downsample))
+ Region[:,0] = np.int32(np.round((Region[:,0] - bounds['x_min_pad']) / downsample))
+ # get annotation ID for mask color
+ ID = IDs[index]
+ cv2.fillPoly(mask, [Region], int(ID['annotationID']))
+ index = index + 1
+
+ # reshape mask
+ x_start = np.int32(np.round((bounds['x_min'] - bounds['x_min_pad']) / downsample))
+ y_start = np.int32(np.round((bounds['y_min'] - bounds['y_min_pad']) / downsample))
+ x_stop = np.int32(np.round((bounds['x_max'] - bounds['x_min_pad']) / downsample))
+ y_stop = np.int32(np.round((bounds['y_max'] - bounds['y_min_pad']) / downsample))
+ # pull center mask region
+ mask = mask[ y_start:y_stop, x_start:x_stop ]
+
+ else: # no Regions
+ mask = np.zeros([ int(np.round((bounds['y_max'] - bounds['y_min']) / downsample)), int(np.round((bounds['x_max'] - bounds['x_min']) / downsample)) ], dtype=np.uint8)
+
+ return mask
+
+def write_minmax_to_xml(xml_path, tree=None, time_buffer=10, get_absolute_max=False):
+    # function to write min and max vertices to each region
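+    # If the xml already carries a 'modtime' attribute and the file has not been
+    # modified since (within time_buffer seconds), the cached Xmin/Xmax/Ymin/Ymax
+    # values are reused; otherwise every region's bounding box is recomputed and
+    # written back. With get_absolute_max=True nothing is written and the maximum
+    # (X, Y) over all annotations is returned instead.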
+
+ # parse xml and get root
+ if tree == None: tree = ET.parse(xml_path)
+ root = tree.getroot()
+
+ try:
+ if get_absolute_max:
+ # break the try statement
+ X_max = 0
+ Y_max = 0
+ raise ValueError
+
+ # has the xml been modified to include minmax
+ modtime = np.float64(root.attrib['modtime'])
+ # has the minmax modified xml been changed?
+ assert os.path.getmtime(xml_path) < modtime + time_buffer
+
+ except:
+
+ for Annotation in root.findall("./Annotation"): # for all annotations
+ annotationID = Annotation.attrib['Id']
+
+ for Region in Annotation.findall("./*/Region"): # iterate on all region
+
+ for Vert in Region.findall("./Vertices"): # iterate on all vertex in region
+ Xs = []
+ Ys = []
+ for Vertex in Vert.findall("./Vertex"): # iterate on all vertex in region
+ # get points
+ Xs.append(np.int32(np.float64(Vertex.attrib['X'])))
+ Ys.append(np.int32(np.float64(Vertex.attrib['Y'])))
+
+ # find min and max points
+ Xs = np.array(Xs)
+ Ys = np.array(Ys)
+
+ if get_absolute_max:
+ # get the biggest point in annotation
+                    if Xs.size > 0 and Ys.size > 0:
+ X_max = max(X_max, np.max(Xs))
+ Y_max = max(Y_max, np.max(Ys))
+
+ else:
+ # modify the xml
+ Vert.set("Xmin", "{}".format(np.min(Xs)))
+ Vert.set("Xmax", "{}".format(np.max(Xs)))
+ Vert.set("Ymin", "{}".format(np.min(Ys)))
+ Vert.set("Ymax", "{}".format(np.max(Ys)))
+
+ if get_absolute_max:
+ # return annotation max point
+ return (X_max,Y_max)
+
+ else:
+ # modify the xml with minmax region info
+ root.set("modtime", "{}".format(time.time()))
+ xml_data = ET.tostring(tree, pretty_print=True)
+ #xml_data = Annotations.toprettyxml()
+ f = open(xml_path, 'w')
+ f.write(xml_data.decode())
+ f.close()
+
+
+def get_num_classes(xml_path,ignore_label=None):
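+    # counts the annotation layers in the xml (optionally skipping ignore_label)
+    # and adds 1 for the background class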
+ # parse xml and get root
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+
+ annotation_num = 0
+ for Annotation in root.findall("./Annotation"): # for all annotations
+ if ignore_label != None:
+ if not int(Annotation.attrib['Id']) == ignore_label:
+ annotation_num += 1
+ else: annotation_num += 1
+
+ return annotation_num + 1
diff --git a/setup.py b/setup.py
index fff61e6..3581056 100644
--- a/setup.py
+++ b/setup.py
@@ -45,12 +45,12 @@ def prerelease_local_scheme(version):
install_requires=[
# scientific packages
'nimfa>=1.3.2',
- 'numpy>=1.21.1',
+ 'numpy==1.23.5',
'scipy>=0.19.0',
'Pillow==9.5.0',
'pandas>=0.19.2',
'imageio>=2.3.0',
- # 'shapely[vectorized]',
+ 'shapely',
#'opencv-python-headless<4.7',
#'sqlalchemy',
# 'matplotlib',
@@ -69,6 +69,7 @@ def prerelease_local_scheme(version):
# 'umap-learn==0.5.3',
'openpyxl',
'xlrd<2',
+ 'imgaug',
# dask packages
'dask[dataframe]>=1.1.0',
'distributed>=1.21.6',