Today I’m going to write about a Kaggle competition I started working on recently. I will show you how to approach the problem using the U-Net neural network architecture in Keras. In the TGS Salt Identification Challenge, you are asked to segment salt deposits beneath the Earth’s surface. We are given a set of seismic images, 101×101 pixels each, in which every pixel is labeled as either salt or sediment, and the goal is to segment the regions that contain salt. A seismic image is produced by imaging the reflections coming from rock boundaries, so it shows the boundaries between different rock types. Let’s look at some of the images and their labels.

In [1]:
import os
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
%matplotlib inline

from tqdm import tqdm_notebook, tnrange
from itertools import chain
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from sklearn.model_selection import train_test_split

import tensorflow as tf

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
Using TensorFlow backend.
In [2]:
# Set some parameters
im_width = 128
im_height = 128
border = 5
path_train = '../input/train/'
path_test = '../input/test/'

Load the images

In [3]:
# Get and resize train images and masks
def get_data(path, train=True):
    ids = next(os.walk(path + "images"))[2]
    X = np.zeros((len(ids), im_height, im_width, 1), dtype=np.float32)
    if train:
        y = np.zeros((len(ids), im_height, im_width, 1), dtype=np.float32)
    print('Getting and resizing images ... ')
    for n, id_ in tqdm_notebook(enumerate(ids), total=len(ids)):
        # Load images
        img = load_img(path + 'images/' + id_, grayscale=True)
        x_img = img_to_array(img)
        x_img = resize(x_img, (im_height, im_width, 1), mode='constant', preserve_range=True)

        # Load masks
        if train:
            mask = img_to_array(load_img(path + 'masks/' + id_, grayscale=True))
            mask = resize(mask, (im_height, im_width, 1), mode='constant', preserve_range=True)

        # Save images
        X[n, ..., 0] = x_img.squeeze() / 255
        if train:
            y[n] = mask / 255
    print('Done!')
    if train:
        return X, y
    else:
        return X
    
X, y = get_data(path_train, train=True)
Getting and resizing images ...
Done!
In [4]:
# Split train and valid
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.15, random_state=2018)
In [5]:
# Check if training data looks all right
ix = random.randint(0, len(X_train) - 1)
has_mask = y_train[ix].max() > 0

fig, ax = plt.subplots(1, 2, figsize=(20, 10))

ax[0].imshow(X_train[ix, ..., 0], cmap='seismic', interpolation='bilinear')
if has_mask:
    ax[0].contour(y_train[ix].squeeze(), colors='k', levels=[0.5])
ax[0].set_title('Seismic')

ax[1].imshow(y_train[ix].squeeze(), interpolation='bilinear', cmap='gray')
ax[1].set_title('Salt');

The U-Net model

A successful and popular model for this kind of problem is the U-Net architecture. The network is illustrated in Figure 1. It consists of a contracting path (left side) and an expansive path (right side).

Figure 1: The U-Net image segmentation architecture

  • The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3×3 convolutions (each followed by a batch normalization layer and a rectified linear unit (ReLU) activation), a 2×2 max pooling operation with stride 2 for downsampling, and dropout. At each downsampling step we double the number of feature channels. The purpose of this contracting path is to capture the context of the input image so that segmentation is possible.
  • Every step in the expansive path consists of an upsampling of the feature map by a 3×3 transposed convolution with stride 2 (“up-convolution”) that halves the number of feature channels, a concatenation with the corresponding feature map from the contracting path, dropout, and two 3×3 convolutions, each followed by batch normalization and a ReLU. The purpose of this expansive path is to enable precise localization combined with contextual information from the contracting path.
  • At the final layer, a 1×1 convolution is used to map each 16-component feature vector to the desired number of classes.
In [6]:
def conv2d_block(input_tensor, n_filters, kernel_size=3, batchnorm=True):
    # first layer
    x = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), kernel_initializer="he_normal",
               padding="same")(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation("relu")(x)
    # second layer
    x = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), kernel_initializer="he_normal",
               padding="same")(x)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation("relu")(x)
    return x
In [7]:
def get_unet(input_img, n_filters=16, dropout=0.5, batchnorm=True):
    # contracting path
    c1 = conv2d_block(input_img, n_filters=n_filters*1, kernel_size=3, batchnorm=batchnorm)
    p1 = MaxPooling2D((2, 2)) (c1)
    p1 = Dropout(dropout*0.5)(p1)

    c2 = conv2d_block(p1, n_filters=n_filters*2, kernel_size=3, batchnorm=batchnorm)
    p2 = MaxPooling2D((2, 2)) (c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters=n_filters*4, kernel_size=3, batchnorm=batchnorm)
    p3 = MaxPooling2D((2, 2)) (c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters=n_filters*8, kernel_size=3, batchnorm=batchnorm)
    p4 = MaxPooling2D(pool_size=(2, 2)) (c4)
    p4 = Dropout(dropout)(p4)
    
    c5 = conv2d_block(p4, n_filters=n_filters*16, kernel_size=3, batchnorm=batchnorm)
    
    # expansive path
    u6 = Conv2DTranspose(n_filters*8, (3, 3), strides=(2, 2), padding='same') (c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters=n_filters*8, kernel_size=3, batchnorm=batchnorm)

    u7 = Conv2DTranspose(n_filters*4, (3, 3), strides=(2, 2), padding='same') (c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters=n_filters*4, kernel_size=3, batchnorm=batchnorm)

    u8 = Conv2DTranspose(n_filters*2, (3, 3), strides=(2, 2), padding='same') (c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters=n_filters*2, kernel_size=3, batchnorm=batchnorm)

    u9 = Conv2DTranspose(n_filters*1, (3, 3), strides=(2, 2), padding='same') (c8)
    u9 = concatenate([u9, c1], axis=3)
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters=n_filters*1, kernel_size=3, batchnorm=batchnorm)
    
    outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)
    model = Model(inputs=[input_img], outputs=[outputs])
    return model

Since dropout does not seem to work well for me in this competition, we set it to a low value. Batch normalization improves the training quite a lot; a quick way to check this yourself is sketched below.
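
If you want to verify the effect of batch normalization, a hypothetical ablation is to build an otherwise identical model with batchnorm=False and compare the learning curves (the names inp_no_bn and model_no_bn are mine, not part of the original notebook):

# Hypothetical sanity check: same architecture, no batch normalization
inp_no_bn = Input((im_height, im_width, 1), name='img_no_bn')
model_no_bn = get_unet(inp_no_bn, n_filters=16, dropout=0.05, batchnorm=False)
model_no_bn.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy"])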

In [8]:
input_img = Input((im_height, im_width, 1), name='img')
model = get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True)

model.compile(optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
img (InputLayer)                (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 16) 160         img[0][0]                        
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 16) 64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 16) 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 16) 2320        activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128, 128, 16) 64          conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 128, 128, 16) 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 64, 64, 16)   0           activation_2[0][0]               
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 64, 64, 16)   0           max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 32)   4640        dropout_1[0][0]                  
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 64, 64, 32)   128         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 32)   0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 32)   9248        activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 32)   128         conv2d_4[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 64, 64, 32)   0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 32, 32, 32)   0           activation_4[0][0]               
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 32, 32, 32)   0           max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 32, 32, 64)   18496       dropout_2[0][0]                  
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 64)   256         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 32, 32, 64)   0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 32, 32, 64)   36928       activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 32, 32, 64)   256         conv2d_6[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 64)   0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 16, 16, 64)   0           activation_6[0][0]               
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 16, 16, 64)   0           max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 16, 16, 128)  73856       dropout_3[0][0]                  
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 16, 16, 128)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 16, 16, 128)  0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 16, 16, 128)  147584      activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 16, 16, 128)  512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 16, 16, 128)  0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 8, 8, 128)    0           activation_8[0][0]               
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 8, 8, 128)    0           max_pooling2d_4[0][0]            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 256)    295168      dropout_4[0][0]                  
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 8, 8, 256)    1024        conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 8, 8, 256)    0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 8, 8, 256)    590080      activation_9[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 8, 8, 256)    1024        conv2d_10[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 8, 8, 256)    0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 16, 16, 128)  295040      activation_10[0][0]              
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 16, 16, 256)  0           conv2d_transpose_1[0][0]         
                                                                 activation_8[0][0]               
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 16, 16, 256)  0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 16, 16, 128)  295040      dropout_5[0][0]                  
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 16, 16, 128)  512         conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 16, 16, 128)  0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 16, 16, 128)  147584      activation_11[0][0]              
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 16, 16, 128)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 16, 16, 128)  0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 32, 32, 64)   73792       activation_12[0][0]              
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 128)  0           conv2d_transpose_2[0][0]         
                                                                 activation_6[0][0]               
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 32, 32, 128)  0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 64)   73792       dropout_6[0][0]                  
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 32, 32, 64)   256         conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 32, 32, 64)   0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 32, 32, 64)   36928       activation_13[0][0]              
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 64)   256         conv2d_14[0][0]                  
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 32, 32, 64)   0           batch_normalization_14[0][0]     
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 64, 64, 32)   18464       activation_14[0][0]              
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 64)   0           conv2d_transpose_3[0][0]         
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 64, 64, 64)   0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 32)   18464       dropout_7[0][0]                  
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 64, 64, 32)   128         conv2d_15[0][0]                  
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 64, 64, 32)   0           batch_normalization_15[0][0]     
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 64, 64, 32)   9248        activation_15[0][0]              
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 64, 64, 32)   128         conv2d_16[0][0]                  
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 64, 64, 32)   0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 128, 128, 16) 4624        activation_16[0][0]              
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 128, 128, 32) 0           conv2d_transpose_4[0][0]         
                                                                 activation_2[0][0]               
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 128, 128, 32) 0           concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 128, 128, 16) 4624        dropout_8[0][0]                  
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 128, 128, 16) 64          conv2d_17[0][0]                  
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 128, 128, 16) 0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 128, 128, 16) 2320        activation_17[0][0]              
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 128, 128, 16) 64          conv2d_18[0][0]                  
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 128, 128, 16) 0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 128, 128, 1)  17          activation_18[0][0]              
==================================================================================================
Total params: 2,164,305
Trainable params: 2,161,361
Non-trainable params: 2,944
__________________________________________________________________________________________________

Now we can train the model. We use some callbacks to save the model during training, lower the learning rate if the validation loss plateaus, and perform early stopping.

In [9]:
callbacks = [
    EarlyStopping(patience=10, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
    ModelCheckpoint('model-tgs-salt.h5', verbose=1, save_best_only=True, save_weights_only=True)
]
In [10]:
results = model.fit(X_train, y_train, batch_size=32, epochs=100, callbacks=callbacks,
                    validation_data=(X_valid, y_valid))
Train on 3400 samples, validate on 600 samples
Epoch 1/100
3400/3400 [==============================] - 37s 11ms/step - loss: 0.3942 - acc: 0.8351 - val_loss: 1.1071 - val_acc: 0.6706

Epoch 00001: val_loss improved from inf to 1.10712, saving model to model-tgs-salt.h5
Epoch 2/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2880 - acc: 0.8840 - val_loss: 0.5479 - val_acc: 0.8455

Epoch 00002: val_loss improved from 1.10712 to 0.54791, saving model to model-tgs-salt.h5
Epoch 3/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2590 - acc: 0.8935 - val_loss: 0.3577 - val_acc: 0.8623

Epoch 00003: val_loss improved from 0.54791 to 0.35774, saving model to model-tgs-salt.h5
Epoch 4/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2447 - acc: 0.8975 - val_loss: 0.4058 - val_acc: 0.8325

Epoch 00004: val_loss did not improve
Epoch 5/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2397 - acc: 0.8985 - val_loss: 0.4099 - val_acc: 0.8431

Epoch 00005: val_loss did not improve
Epoch 6/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2172 - acc: 0.9073 - val_loss: 0.2367 - val_acc: 0.9045

Epoch 00006: val_loss improved from 0.35774 to 0.23671, saving model to model-tgs-salt.h5
Epoch 7/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1962 - acc: 0.9179 - val_loss: 0.3124 - val_acc: 0.8932

Epoch 00007: val_loss did not improve
Epoch 8/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.2019 - acc: 0.9100 - val_loss: 0.2329 - val_acc: 0.9104

Epoch 00008: val_loss improved from 0.23671 to 0.23290, saving model to model-tgs-salt.h5
Epoch 9/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1873 - acc: 0.9182 - val_loss: 0.2109 - val_acc: 0.9044

Epoch 00009: val_loss improved from 0.23290 to 0.21089, saving model to model-tgs-salt.h5
Epoch 10/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1844 - acc: 0.9191 - val_loss: 0.2302 - val_acc: 0.9164

Epoch 00010: val_loss did not improve
Epoch 11/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1717 - acc: 0.9251 - val_loss: 0.2095 - val_acc: 0.9182

Epoch 00011: val_loss improved from 0.21089 to 0.20949, saving model to model-tgs-salt.h5
Epoch 12/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1711 - acc: 0.9231 - val_loss: 0.1764 - val_acc: 0.9227

Epoch 00012: val_loss improved from 0.20949 to 0.17644, saving model to model-tgs-salt.h5
Epoch 13/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1715 - acc: 0.9234 - val_loss: 0.2119 - val_acc: 0.9276

Epoch 00013: val_loss did not improve
Epoch 14/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1618 - acc: 0.9273 - val_loss: 0.1842 - val_acc: 0.9260

Epoch 00014: val_loss did not improve
Epoch 15/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1516 - acc: 0.9310 - val_loss: 0.2616 - val_acc: 0.8964

Epoch 00015: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.

Epoch 00015: val_loss did not improve
Epoch 16/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1311 - acc: 0.9401 - val_loss: 0.1491 - val_acc: 0.9348

Epoch 00016: val_loss improved from 0.17644 to 0.14907, saving model to model-tgs-salt.h5
Epoch 17/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1237 - acc: 0.9431 - val_loss: 0.1485 - val_acc: 0.9359

Epoch 00017: val_loss improved from 0.14907 to 0.14852, saving model to model-tgs-salt.h5
Epoch 18/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1135 - acc: 0.9457 - val_loss: 0.1508 - val_acc: 0.9377

Epoch 00018: val_loss did not improve
Epoch 19/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1178 - acc: 0.9444 - val_loss: 0.1498 - val_acc: 0.9358

Epoch 00019: val_loss did not improve
Epoch 20/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1097 - acc: 0.9487 - val_loss: 0.1504 - val_acc: 0.9353

Epoch 00020: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.

Epoch 00020: val_loss did not improve
Epoch 21/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1086 - acc: 0.9493 - val_loss: 0.1490 - val_acc: 0.9352

Epoch 00021: val_loss did not improve
Epoch 22/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1053 - acc: 0.9499 - val_loss: 0.1490 - val_acc: 0.9352

Epoch 00022: val_loss did not improve
Epoch 23/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1045 - acc: 0.9505 - val_loss: 0.1495 - val_acc: 0.9346

Epoch 00023: ReduceLROnPlateau reducing learning rate to 1e-05.

Epoch 00023: val_loss did not improve
Epoch 24/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1026 - acc: 0.9511 - val_loss: 0.1495 - val_acc: 0.9347

Epoch 00024: val_loss did not improve
Epoch 25/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1015 - acc: 0.9519 - val_loss: 0.1499 - val_acc: 0.9342

Epoch 00025: val_loss did not improve
Epoch 26/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1035 - acc: 0.9502 - val_loss: 0.1499 - val_acc: 0.9348

Epoch 00026: val_loss did not improve
Epoch 27/100
3400/3400 [==============================] - 30s 9ms/step - loss: 0.1027 - acc: 0.9516 - val_loss: 0.1500 - val_acc: 0.9350

Epoch 00027: val_loss did not improve
Epoch 00027: early stopping
In [11]:
plt.figure(figsize=(8, 8))
plt.title("Learning curve")
plt.plot(results.history["loss"], label="loss")
plt.plot(results.history["val_loss"], label="val_loss")
plt.plot(np.argmin(results.history["val_loss"]), np.min(results.history["val_loss"]), marker="x", color="r", label="best model")
plt.xlabel("Epochs")
plt.ylabel("log_loss")
plt.legend();

Inference with the model

Now we can load the best model we saved and look at some predictions.

In [12]:
# Load best model
model.load_weights('model-tgs-salt.h5')
In [13]:
# Evaluate on the validation set (this should match the best val_loss logged above)
model.evaluate(X_valid, y_valid, verbose=1)
600/600 [==============================] - 1s 2ms/step
Out[13]:
[0.14852104584376016, 0.9359497062365214]
In [14]:
# Predict on train, val and test
preds_train = model.predict(X_train, verbose=1)
preds_val = model.predict(X_valid, verbose=1)

# Threshold predictions
preds_train_t = (preds_train > 0.5).astype(np.uint8)
preds_val_t = (preds_val > 0.5).astype(np.uint8)
3400/3400 [==============================] - 9s 3ms/step
600/600 [==============================] - 1s 2ms/step

An important step we skip here is to select an appropriate threshold for the model. This is normally done by optimizing the threshold on a holdout set, as in the sketch below. After that, let’s look at some predictions.
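
For reference, here is a minimal sketch of such a threshold search on our validation set. I use plain mean IoU as the selection criterion, which is only a rough proxy for the competition metric (mean average precision over several IoU thresholds); the names mean_iou_at and best_threshold are mine:

def mean_iou_at(threshold, y_true, y_prob):
    # Mean intersection-over-union of the binarized predictions over all images
    ious = []
    for t, p in zip(y_true > 0.5, y_prob > threshold):
        union = np.logical_or(t, p).sum()
        inter = np.logical_and(t, p).sum()
        ious.append(inter / union if union > 0 else 1.0)
    return np.mean(ious)

thresholds = np.linspace(0.2, 0.8, 31)
scores = [mean_iou_at(t, y_valid, preds_val) for t in thresholds]
best_threshold = thresholds[np.argmax(scores)]
print(best_threshold, max(scores))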

In [15]:
def plot_sample(X, y, preds, binary_preds, ix=None):
    if ix is None:
        ix = random.randint(0, len(X) - 1)

    has_mask = y[ix].max() > 0

    fig, ax = plt.subplots(1, 4, figsize=(20, 10))
    ax[0].imshow(X[ix, ..., 0], cmap='seismic')
    if has_mask:
        ax[0].contour(y[ix].squeeze(), colors='k', levels=[0.5])
    ax[0].set_title('Seismic')

    ax[1].imshow(y[ix].squeeze())
    ax[1].set_title('Salt')

    ax[2].imshow(preds[ix].squeeze(), vmin=0, vmax=1)
    if has_mask:
        ax[2].contour(y[ix].squeeze(), colors='k', levels=[0.5])
    ax[2].set_title('Salt Predicted')
    
    ax[3].imshow(binary_preds[ix].squeeze(), vmin=0, vmax=1)
    if has_mask:
        ax[3].contour(y[ix].squeeze(), colors='k', levels=[0.5])
    ax[3].set_title('Salt Predicted binary');
In [16]:
# Check if training data looks all right
plot_sample(X_train, y_train, preds_train, preds_train_t, ix=14)
In [17]:
# Check if valid data looks all right
plot_sample(X_valid, y_valid, preds_val, preds_val_t, ix=19)

This looks like we are going in the right direction, but there is obviously still room for improvement. That’s it for now, though. You can now use the model to predict on the test images and submit your predictions to the competition, as sketched below. With an appropriate method for choosing the threshold, this should give you a score of around 0.7 on the leaderboard.
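
To make this concrete, here is a minimal sketch of one way to build a submission file. It downscales each probability map back to the original 101×101 before binarizing, and uses the standard Kaggle column-major run-length encoding; rle_encode is my own helper, best_threshold comes from the sketch above (or simply use 0.5), and I assume .png filenames and the usual id/rle_mask columns from the sample submission:

X_test = get_data(path_test, train=False)
preds_test = model.predict(X_test, verbose=1)

def rle_encode(mask):
    # Run-length encoding in column-major (Fortran) order, 1-indexed starts
    pixels = np.concatenate([[0], mask.flatten(order='F'), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]  # turn end positions into run lengths
    return ' '.join(str(x) for x in runs)

test_ids = next(os.walk(path_test + "images"))[2]
rles = []
for i, id_ in enumerate(test_ids):
    # Resize the probability map back to 101x101, then binarize
    prob = resize(preds_test[i].squeeze(), (101, 101), mode='constant', preserve_range=True)
    rles.append(rle_encode((prob > best_threshold).astype(np.uint8)))

sub = pd.DataFrame({'id': [id_[:-4] for id_ in test_ids],  # strip '.png'
                    'rle_mask': rles})
sub.to_csv('submission.csv', index=False)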

I hope you enjoyed this post and learned something. Join the competition and try the model yourself. In the next post, I will show you how to improve the model with data augmentation and test time augmentation.
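
As a small preview of the augmentation post, here is one hypothetical way to wire up flip augmentation with the ImageDataGenerator we imported at the top. The key detail is that the image and mask generators must share the same seed so each mask gets exactly the same transform as its image (this is a sketch, not the exact setup I will use):

data_gen_args = dict(horizontal_flip=True)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

# The same seed guarantees identical random flips for images and masks
image_generator = image_datagen.flow(X_train, batch_size=32, seed=2018)
mask_generator = mask_datagen.flow(y_train, batch_size=32, seed=2018)
train_generator = zip(image_generator, mask_generator)

# model.fit_generator(train_generator, steps_per_epoch=len(X_train) // 32,
#                     epochs=100, callbacks=callbacks,
#                     validation_data=(X_valid, y_valid))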
