I'm trying to train a 2D U-Net convolutional network with a (256, 256, 2) input to predict an output of the same dimensions. I've been having some issues preparing the training data for loading, which I thought was the problem, but the error trace points at the first convolutional layer of the model (conv2d_105). All the main conv layers use (3, 3) filters. How can I fix this issue?
I considered that I might need to use 3D layers instead of 2D, but I've seen tutorials that use RGB images with 2D convolutional layers, so I don't think this is the issue.
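As a sanity check, I'd expect a minimal 2D convolution over 2-channel dummy data to work on its own, something like this (a standalone sketch, not my real pipeline):

from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
import numpy as np

# Minimal 2-channel model: a single (3, 3) Conv2D, like my first layer
inputs = Input((256, 256, 2))
outputs = Conv2D(16, (3, 3), padding='same')(inputs)
model = Model(inputs, outputs)

# Dummy batch of four 2-channel images
x = np.random.rand(4, 256, 256, 2).astype('float32')
print(model.predict(x).shape)  # (4, 256, 256, 16)

The full error I get when training on my actual data: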
InvalidArgumentError: input depth must be evenly divisible by filter depth: 3 vs 2
[[node model_7/conv2d_105/BiasAdd (defined at <ipython-input-9-957db582cfa0>:171) ]] [Op:__inference_train_function_19507]
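If I'm reading the "3 vs 2" correctly, the data reaching conv2d_105 has 3 channels while the layer's filters were built for 2, so I'm checking the arrays right before fit (X_train / Y_train stand in for my actual array names):

# X_train / Y_train are placeholders for my real training arrays
print(X_train.shape, X_train.dtype)  # expecting (N, 256, 256, 2)
print(Y_train.shape, Y_train.dtype)  # expecting (N, 256, 256, 2)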
Architecture
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_8 (InputLayer) [(None, 256, 256, 2)] 0
__________________________________________________________________________________________________
conv2d_105 (Conv2D) (None, 256, 256, 16) 304 input_8[0][0]
__________________________________________________________________________________________________
batch_normalization_98 (BatchNo (None, 256, 256, 16) 64 conv2d_105[0][0]
__________________________________________________________________________________________________
activation_98 (Activation) (None, 256, 256, 16) 0 batch_normalization_98[0][0]
__________________________________________________________________________________________________
conv2d_106 (Conv2D) (None, 256, 256, 16) 2320 activation_98[0][0]
__________________________________________________________________________________________________
batch_normalization_99 (BatchNo (None, 256, 256, 16) 64 conv2d_106[0][0]
__________________________________________________________________________________________________
activation_99 (Activation) (None, 256, 256, 16) 0 batch_normalization_99[0][0]
__________________________________________________________________________________________________
average_pooling2d_21 (AveragePo (None, 128, 128, 16) 0 activation_99[0][0]
__________________________________________________________________________________________________
conv2d_107 (Conv2D) (None, 128, 128, 32) 4640 average_pooling2d_21[0][0]
__________________________________________________________________________________________________
batch_normalization_100 (BatchN (None, 128, 128, 32) 128 conv2d_107[0][0]
__________________________________________________________________________________________________
activation_100 (Activation) (None, 128, 128, 32) 0 batch_normalization_100[0][0]
__________________________________________________________________________________________________
conv2d_108 (Conv2D) (None, 128, 128, 32) 9248 activation_100[0][0]
__________________________________________________________________________________________________
batch_normalization_101 (BatchN (None, 128, 128, 32) 128 conv2d_108[0][0]
__________________________________________________________________________________________________
activation_101 (Activation) (None, 128, 128, 32) 0 batch_normalization_101[0][0]
__________________________________________________________________________________________________
average_pooling2d_22 (AveragePo (None, 64, 64, 32) 0 activation_101[0][0]
__________________________________________________________________________________________________
conv2d_109 (Conv2D) (None, 64, 64, 64) 18496 average_pooling2d_22[0][0]
__________________________________________________________________________________________________
batch_normalization_102 (BatchN (None, 64, 64, 64) 256 conv2d_109[0][0]
__________________________________________________________________________________________________
activation_102 (Activation) (None, 64, 64, 64) 0 batch_normalization_102[0][0]
__________________________________________________________________________________________________
conv2d_110 (Conv2D) (None, 64, 64, 64) 36928 activation_102[0][0]
__________________________________________________________________________________________________
batch_normalization_103 (BatchN (None, 64, 64, 64) 256 conv2d_110[0][0]
__________________________________________________________________________________________________
activation_103 (Activation) (None, 64, 64, 64) 0 batch_normalization_103[0][0]
__________________________________________________________________________________________________
average_pooling2d_23 (AveragePo (None, 32, 32, 64) 0 activation_103[0][0]
__________________________________________________________________________________________________
conv2d_111 (Conv2D) (None, 32, 32, 128) 73856 average_pooling2d_23[0][0]
__________________________________________________________________________________________________
batch_normalization_104 (BatchN (None, 32, 32, 128) 512 conv2d_111[0][0]
__________________________________________________________________________________________________
activation_104 (Activation) (None, 32, 32, 128) 0 batch_normalization_104[0][0]
__________________________________________________________________________________________________
conv2d_112 (Conv2D) (None, 32, 32, 128) 147584 activation_104[0][0]
__________________________________________________________________________________________________
batch_normalization_105 (BatchN (None, 32, 32, 128) 512 conv2d_112[0][0]
__________________________________________________________________________________________________
activation_105 (Activation) (None, 32, 32, 128) 0 batch_normalization_105[0][0]
__________________________________________________________________________________________________
conv2d_transpose_21 (Conv2DTran (None, 64, 64, 64) 32832 activation_105[0][0]
__________________________________________________________________________________________________
concatenate_21 (Concatenate) (None, 64, 64, 128) 0 conv2d_transpose_21[0][0]
activation_103[0][0]
__________________________________________________________________________________________________
conv2d_113 (Conv2D) (None, 64, 64, 64) 73792 concatenate_21[0][0]
__________________________________________________________________________________________________
batch_normalization_106 (BatchN (None, 64, 64, 64) 256 conv2d_113[0][0]
__________________________________________________________________________________________________
activation_106 (Activation) (None, 64, 64, 64) 0 batch_normalization_106[0][0]
__________________________________________________________________________________________________
conv2d_114 (Conv2D) (None, 64, 64, 64) 36928 activation_106[0][0]
__________________________________________________________________________________________________
batch_normalization_107 (BatchN (None, 64, 64, 64) 256 conv2d_114[0][0]
__________________________________________________________________________________________________
activation_107 (Activation) (None, 64, 64, 64) 0 batch_normalization_107[0][0]
__________________________________________________________________________________________________
conv2d_transpose_22 (Conv2DTran (None, 128, 128, 32) 8224 activation_107[0][0]
__________________________________________________________________________________________________
concatenate_22 (Concatenate) (None, 128, 128, 64) 0 conv2d_transpose_22[0][0]
activation_101[0][0]
__________________________________________________________________________________________________
conv2d_115 (Conv2D) (None, 128, 128, 32) 18464 concatenate_22[0][0]
__________________________________________________________________________________________________
batch_normalization_108 (BatchN (None, 128, 128, 32) 128 conv2d_115[0][0]
__________________________________________________________________________________________________
activation_108 (Activation) (None, 128, 128, 32) 0 batch_normalization_108[0][0]
__________________________________________________________________________________________________
conv2d_116 (Conv2D) (None, 128, 128, 32) 9248 activation_108[0][0]
__________________________________________________________________________________________________
batch_normalization_109 (BatchN (None, 128, 128, 32) 128 conv2d_116[0][0]
__________________________________________________________________________________________________
activation_109 (Activation) (None, 128, 128, 32) 0 batch_normalization_109[0][0]
__________________________________________________________________________________________________
conv2d_transpose_23 (Conv2DTran (None, 256, 256, 16) 2064 activation_109[0][0]
__________________________________________________________________________________________________
concatenate_23 (Concatenate) (None, 256, 256, 32) 0 conv2d_transpose_23[0][0]
activation_99[0][0]
__________________________________________________________________________________________________
conv2d_117 (Conv2D) (None, 256, 256, 16) 4624 concatenate_23[0][0]
__________________________________________________________________________________________________
batch_normalization_110 (BatchN (None, 256, 256, 16) 64 conv2d_117[0][0]
__________________________________________________________________________________________________
activation_110 (Activation) (None, 256, 256, 16) 0 batch_normalization_110[0][0]
__________________________________________________________________________________________________
conv2d_118 (Conv2D) (None, 256, 256, 16) 3856 activation_110[0][0]
__________________________________________________________________________________________________
batch_normalization_111 (BatchN (None, 256, 256, 16) 64 conv2d_118[0][0]
__________________________________________________________________________________________________
activation_111 (Activation) (None, 256, 256, 16) 0 batch_normalization_111[0][0]
__________________________________________________________________________________________________
conv2d_119 (Conv2D) (None, 256, 256, 2) 34 activation_111[0][0]
Model code for first conv block
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation

padding = 'same'
filter = (3, 3)  # kernel size used by every conv layer in the main blocks
inputs = Input((256, 256, 2))
k_init = 'he_normal'
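For completeness, the block continues as in the summary above; a sketch reconstructed from it, where I've written ReLU since the summary only shows generic Activation layers:

x = Conv2D(16, filter, padding=padding, kernel_initializer=k_init)(inputs)  # conv2d_105: 304 params
x = BatchNormalization()(x)                                                 # batch_normalization_98: 64 params
x = Activation('relu')(x)
x = Conv2D(16, filter, padding=padding, kernel_initializer=k_init)(x)       # conv2d_106: 2320 params
x = BatchNormalization()(x)
x = Activation('relu')(x)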