R/createResUnetModel.R
createResUnetModel2D.Rd
Creates a keras model of the U-net + ResNet deep learning architecture for image segmentation and regression, based on the paper available here:
```r
createResUnetModel2D(
  inputImageSize,
  numberOfOutputs = 1,
  numberOfFiltersAtBaseLayer = 32,
  bottleNeckBlockDepthSchedule = c(3, 4),
  convolutionKernelSize = c(3, 3),
  deconvolutionKernelSize = c(2, 2),
  dropoutRate = 0,
  weightDecay = 0.0001,
  mode = c("classification", "regression")
)
```
| Argument | Description |
|---|---|
| inputImageSize | Used for specifying the input tensor shape. The shape (or dimension) of that tensor is the image dimensions followed by the number of channels (e.g., red, green, and blue). The batch size (i.e., number of training images) is not specified a priori. |
| numberOfOutputs | Meaning depends on the mode. For 'classification' this is the number of segmentation labels; for 'regression' this is the number of outputs. |
| numberOfFiltersAtBaseLayer | Number of filters at the beginning and end of the 'U'. |
| bottleNeckBlockDepthSchedule | Vector that provides the encoding layer schedule for the number of bottleneck blocks per long skip connection. |
| convolutionKernelSize | 2-d vector defining the kernel size during the encoding path. |
| deconvolutionKernelSize | 2-d vector defining the kernel size during the decoding path. |
| dropoutRate | Float between 0 and 1 to use between dense layers. |
| weightDecay | Weighting parameter for L2 regularization of the kernel weights of the convolution layers. Default = 0.0001. |
| mode | 'classification' or 'regression'. |
a res/u-net keras model
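For orientation, here is a minimal instantiation sketch in each mode. The input size `c( 64, 64, 1 )` and the argument values shown are illustrative assumptions, not defaults taken from the signature above:

```r
library( ANTsRNet )

# 'classification' mode: numberOfOutputs is the number of segmentation labels.
segmentationModel <- createResUnetModel2D(
  inputImageSize = c( 64, 64, 1 ),   # 64x64 image, 1 channel (illustrative)
  numberOfOutputs = 3,
  mode = "classification" )

# 'regression' mode: numberOfOutputs is the number of continuous outputs.
regressionModel <- createResUnetModel2D(
  inputImageSize = c( 64, 64, 1 ),
  numberOfOutputs = 1,
  dropoutRate = 0.2,                 # illustrative; default is 0
  weightDecay = 1e-4,
  mode = "regression" )
```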
\url{https://arxiv.org/abs/1608.04117}
This particular implementation was ported from the following Python implementation:
\url{https://github.com/veugene/fcn_maker/}
Tustison NJ
```r
library( ANTsR )
library( ANTsRNet )
library( keras )

imageIDs <- c( "r16", "r27", "r30", "r62", "r64", "r85" )
trainingBatchSize <- length( imageIDs )

# Perform simple 3-tissue segmentation.
segmentationLabels <- c( 1, 2, 3 )
numberOfLabels <- length( segmentationLabels )
initialization <- paste0( 'KMeans[', numberOfLabels, ']' )

domainImage <- antsImageRead( getANTsRData( imageIDs[1] ) )

X_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ), 1 ) )
Y_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ) ) )

images <- list()
segmentations <- list()

for( i in seq_len( trainingBatchSize ) )
  {
  cat( "Processing image", imageIDs[i], "\n" )
  image <- antsImageRead( getANTsRData( imageIDs[i] ) )
  mask <- getMask( image )
  segmentation <- atropos( image, mask, initialization )$segmentation

  X_train[i,,,1] <- as.array( image )
  Y_train[i,,] <- as.array( segmentation )
  }

Y_train <- encodeUnet( Y_train, segmentationLabels )

# Perform a simple normalization
X_train <- ( X_train - mean( X_train ) ) / sd( X_train )

# Create the model
model <- createResUnetModel2D( c( dim( domainImage ), 1 ),
  numberOfOutputs = numberOfLabels )

metric_multilabel_dice_coefficient <-
  custom_metric( "multilabel_dice_coefficient",
    multilabel_dice_coefficient )

loss_dice <- function( y_true, y_pred ) {
  -multilabel_dice_coefficient( y_true, y_pred )
}
attr( loss_dice, "py_function_name" ) <- "multilabel_dice_coefficient"

model %>% compile( loss = loss_dice,
  optimizer = optimizer_adam( lr = 0.0001 ),
  metrics = c( metric_multilabel_dice_coefficient,
    metric_categorical_crossentropy ) )

# Comment out the rest due to travis build constraints

# Fit the model
# track <- model %>% fit( X_train, Y_train,
#   epochs = 100, batch_size = 4, verbose = 1, shuffle = TRUE,
#   callbacks = list(
#     callback_model_checkpoint( "resUnetModelInterimWeights.h5",
#       monitor = 'val_loss', save_best_only = TRUE ),
#     callback_reduce_lr_on_plateau( monitor = "val_loss", factor = 0.1 )
#   ),
#   validation_split = 0.2 )

rm( X_train ); gc()

# Save the model and/or save the model weights
# save_model_hdf5( model, filepath = 'resUnetModel.h5' )
# save_model_weights_hdf5( model, filepath = 'resUnetModelWeights.h5' )

rm( model ); gc()
```
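Continuing the example above, predictions can be decoded back into ANTs probability images. This is a minimal sketch, assuming a trained `model` (i.e., the `fit()` call above was run before `rm( model )`), that `decodeUnet()` is the counterpart of the `encodeUnet()` call used above, and a hypothetical `X_test` array shaped and normalized like `X_train`:

```r
# X_test: hypothetical held-out images, shaped and normalized like X_train.
predictedData <- model %>% predict( X_test, verbose = 0 )

# decodeUnet() (assumed counterpart of encodeUnet() above) converts the
# one-hot prediction array back to a list of probability images in the
# space of domainImage, one per segmentation label.
probabilityImages <- decodeUnet( predictedData, domainImage )
```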