R/createDenseUnetModel.R
Creates a keras model of the dense U-net deep learning architecture for image segmentation
createDenseUnetModel2D( inputImageSize, numberOfOutputs = 1L, numberOfLayersPerDenseBlock = c(6, 12, 36, 24), growthRate = 48, initialNumberOfFilters = 96, reductionRate = 0, depth = 7, dropoutRate = 0, weightDecay = 0.0001, mode = c("classification", "regression") )
Argument | Description |
---|---|
inputImageSize | Used for specifying the input tensor shape. The shape (or dimension) of that tensor is the image dimensions followed by the number of channels (e.g., red, green, and blue). |
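numberOfOutputs | Meaning depends on the mode: for "classification" it is the number of segmentation labels; for "regression" it is the number of outputs (default = 1). |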
numberOfLayersPerDenseBlock | number of layers in each dense block; the length of the vector is the number of dense blocks (default = c(6, 12, 36, 24)). |
growthRate | number of filters to add for each dense block layer (default = 48). |
initialNumberOfFilters | number of filters at the beginning (default = 96). |
reductionRate | reduction factor of the transition blocks (default = 0). |
depth | number of layers; must be equal to 3 * N + 4, where N is an integer (default = 7, i.e., N = 1). |
dropoutRate | dropout rate (default = 0). |
weightDecay | weight decay (default = 1e-4). |
mode | A switch to determine the activation function of the output layer: either "classification" (the default) or "regression". |
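The growthRate, initialNumberOfFilters, reductionRate, and depth arguments follow the usual DenseNet bookkeeping. The base-R sketch below illustrates that arithmetic under stated assumptions: it assumes (as in DenseNet, not verified against this function's source) that every layer of a dense block concatenates growthRate new feature maps and that each transition block between dense blocks keeps a fraction 1 - reductionRate of them. isValidDepth() and denseFilterCounts() are hypothetical helpers, not part of ANTsRNet.

```r
# Hypothetical helper (not part of ANTsRNet): checks the documented
# constraint depth = 3 * N + 4 for an integer N >= 0.
isValidDepth <- function( depth )
{
  ( depth - 4 ) >= 0 && ( depth - 4 ) %% 3 == 0
}

# Hypothetical helper: DenseNet-style filter counts at the output of each
# dense block. Assumes each layer adds `growthRate` feature maps and each
# transition block (after every dense block except the last) keeps a
# fraction (1 - reductionRate) of the feature maps.
denseFilterCounts <- function( numberOfLayersPerDenseBlock = c( 6, 12, 36, 24 ),
                               growthRate = 48,
                               initialNumberOfFilters = 96,
                               reductionRate = 0 )
{
  numberOfBlocks <- length( numberOfLayersPerDenseBlock )
  afterBlock <- numeric( numberOfBlocks )
  numberOfFilters <- initialNumberOfFilters
  for( i in seq_len( numberOfBlocks ) )
    {
    # Concatenation inside the dense block grows the channel count linearly.
    numberOfFilters <- numberOfFilters + numberOfLayersPerDenseBlock[i] * growthRate
    afterBlock[i] <- numberOfFilters
    if( i < numberOfBlocks )
      {
      # Assumed transition-block compression.
      numberOfFilters <- floor( numberOfFilters * ( 1 - reductionRate ) )
      }
    }
  afterBlock
}

isValidDepth( 7 )                         # TRUE (N = 1)
isValidDepth( 8 )                         # FALSE
denseFilterCounts()                       # 384 960 2688 3840
denseFilterCounts( reductionRate = 0.5 )  # 384 768 2112 2208
```

With the default reductionRate = 0 no compression occurs, so the channel count grows quickly; with reductionRate = 0.5 the sketch reproduces the familiar DenseNet-161 block outputs, whose layer configuration matches these defaults.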
A DenseUnet keras model.
X. Li, H. Chen, X. Qi, Q. Dou, C.-W. Fu, P.-A. Heng. H-DenseUNet: Hybrid Densely Connected UNet for Liver and Tumor Segmentation from CT Volumes
available here:
https://arxiv.org/pdf/1709.07330.pdf
with the authors' implementation available at:
https://github.com/xmengli999/H-DenseUNet
Tustison NJ
library( ANTsR )
library( ANTsRNet )
library( keras )

imageIDs <- c( "r16", "r27", "r30", "r62", "r64", "r85" )
trainingBatchSize <- length( imageIDs )

# Perform simple 3-tissue segmentation.
segmentationLabels <- c( 1, 2, 3 )
numberOfLabels <- length( segmentationLabels )
initialization <- paste0( 'KMeans[', numberOfLabels, ']' )

domainImage <- antsImageRead( getANTsRData( imageIDs[1] ) )

X_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ), 1 ) )
Y_train <- array( data = NA, dim = c( trainingBatchSize, dim( domainImage ) ) )

images <- list()
segmentations <- list()

for( i in seq_len( trainingBatchSize ) )
  {
  cat( "Processing image", imageIDs[i], "\n" )
  image <- antsImageRead( getANTsRData( imageIDs[i] ) )
  mask <- getMask( image )
  segmentation <- atropos( image, mask, initialization )$segmentation

  X_train[i,,, 1] <- as.array( image )
  Y_train[i,,] <- as.array( segmentation )
  }
Y_train <- encodeUnet( Y_train, segmentationLabels )

# Perform a simple normalization
X_train <- ( X_train - mean( X_train ) ) / sd( X_train )

# Create the model
model <- createDenseUnetModel2D( c( dim( domainImage ), 1L ),
  numberOfOutputs = numberOfLabels )

metric_multilabel_dice_coefficient <-
  custom_metric( "multilabel_dice_coefficient", multilabel_dice_coefficient )

# Use the negative Dice coefficient as the loss to minimize
loss_dice <- function( y_true, y_pred ) {
  -multilabel_dice_coefficient( y_true, y_pred )
}
attr( loss_dice, "py_function_name" ) <- "multilabel_dice_coefficient"

model %>% compile( loss = loss_dice,
  optimizer = optimizer_adam( learning_rate = 0.0001 ),
  metrics = c( metric_multilabel_dice_coefficient,
    metric_categorical_crossentropy ) )

# Comment out the rest due to travis build constraints

# Fit the model
# checkpoint_file <- tempfile( fileext = ".h5" )
# track <- model %>% fit( X_train, Y_train,
#   epochs = 5, batch_size = 4, verbose = 1, shuffle = TRUE,
#   callbacks = list(
#     callback_model_checkpoint( checkpoint_file,
#       monitor = 'val_loss', save_best_only = TRUE ),
#     callback_reduce_lr_on_plateau( monitor = "val_loss", factor = 0.1 )
#   ),
#   validation_split = 0.2 )

# Save the model and/or save the model weights
# save_model_hdf5( model, filepath = 'unetModel.h5' )
# save_model_weights_hdf5( model, filepath = 'unetModelWeights.h5' )

rm( model ); gc()
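The example one-hot encodes the label images with encodeUnet() and trains on the negative multilabel Dice coefficient. The base-R sketch below illustrates both ideas on plain arrays. oneHotEncode() and meanDice() are hypothetical helpers, and the per-label averaged Dice used here is the textbook definition 2|A ∩ B| / (|A| + |B|); it may differ in detail (e.g., background handling or how labels are pooled) from ANTsRNet's multilabel_dice_coefficient().

```r
# Hypothetical helper: one-hot encodes an integer label array (matrix or
# array) by appending a trailing channel dimension, one channel per label.
oneHotEncode <- function( labelArray, labels )
{
  n <- length( labelArray )
  encoded <- array( 0, dim = c( dim( labelArray ), length( labels ) ) )
  for( k in seq_along( labels ) )
    {
    # R arrays are column-major, so channel k occupies the linear
    # indices (k - 1) * n + 1, ..., k * n.
    encoded[( k - 1 ) * n + seq_len( n )] <- as.numeric( labelArray == labels[k] )
    }
  encoded
}

# Hypothetical helper: Dice coefficient per channel, averaged over channels.
# Works for binary or soft (probability) predictions; `smooth` avoids 0 / 0.
meanDice <- function( yTrue, yPred, smooth = 1e-7 )
{
  d <- dim( yTrue )
  numberOfChannels <- d[length( d )]
  n <- length( yTrue ) / numberOfChannels
  dice <- numeric( numberOfChannels )
  for( k in seq_len( numberOfChannels ) )
    {
    idx <- ( k - 1 ) * n + seq_len( n )
    intersection <- sum( yTrue[idx] * yPred[idx] )
    dice[k] <- ( 2 * intersection + smooth ) /
      ( sum( yTrue[idx] ) + sum( yPred[idx] ) + smooth )
    }
  mean( dice )
}

# Toy 2 x 3 label image with labels 1, 2, 3 and a prediction that
# mislabels a single pixel.
truth <- matrix( c( 1, 2, 3, 1, 2, 3 ), nrow = 2 )
prediction <- truth
prediction[1, 1] <- 2

yTrue <- oneHotEncode( truth, c( 1, 2, 3 ) )
yPred <- oneHotEncode( prediction, c( 1, 2, 3 ) )

dim( yTrue )                # 2 3 3
meanDice( yTrue, yTrue )    # 1 (perfect overlap)
meanDice( yTrue, yPred )    # approximately 0.822 (= 37/45)
-meanDice( yTrue, yPred )   # corresponding loss value, as in loss_dice
```

Minimizing the negative Dice, as loss_dice does, is equivalent to maximizing overlap between the predicted and reference label maps.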