# Source code listing for this module (heading left over from a documentation export).

# MIT License
# Copyright (c) 2017 Vooban Inc.
# Coded by: Guillaume Chevalier
# Source to original code and license:

"""Do smooth predictions on an image from tiled prediction patches."""

import numpy as np
import scipy.signal
from tqdm import tqdm
import gc
import math

# PLOT_PROGRESS gates every demo plot in this module. It must exist in both
# execution modes, otherwise the functions below raise NameError on import use.
if __name__ == '__main__':
    # Script/demo mode: matplotlib is only required here.
    import matplotlib.pyplot as plt
    PLOT_PROGRESS = True
    # See end of file for the rest of the __main__.
else:
    # Library mode: never open figures and do not depend on matplotlib.
    PLOT_PROGRESS = False

def _spline_window(window_size, power=2):
    """
    Build the 1D squared-spline (power=2) window used to blend patches.

    The window is assembled from a triangular window ``t``:
    the outer quarters follow ``(2*t)**power / 2`` and the middle half follows
    ``1 - (2*(t - 1))**power / 2``. The result is divided by its own average,
    so the returned window has a mean of 1.

    Args:
        window_size: number of samples in the window.
        power: exponent of the spline pieces (2 gives the squared spline).

    Returns:
        1D numpy array of length ``window_size``.

    Note:
        The quarter width is ``int(window_size / 4)``; for ``window_size < 4``
        it is 0 and the ``-0`` slices below no longer select the intended
        regions, so callers should use ``window_size >= 4``.
    """
    intersection = int(window_size / 4)
    # `scipy.signal.triang` was removed from recent SciPy releases; the
    # window functions live in `scipy.signal.windows`.
    triangle = scipy.signal.windows.triang(window_size)

    # Rising/falling tails: kept only on the outer quarters.
    wind_outer = (abs(2 * triangle) ** power) / 2
    wind_outer[intersection:-intersection] = 0

    # Rounded top: kept only on the middle half.
    wind_inner = 1 - (abs(2 * (triangle - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[-intersection:] = 0

    wind = wind_inner + wind_outer
    return wind / np.average(wind)

# Memoization store for `_window_2D`, keyed by "<window_size>_<power>".
cached_2d_windows = dict()


def _window_2D(window_size, power=2):
    """
    Make a 1D window function, then infer and return a 2D window function.

    Done with an augmentation (two trailing singleton axes) and a
    multiplication of the window with its own transpose.
    Could be generalized to more dimensions.

    Args:
        window_size: side length of the square window.
        power: exponent forwarded to `_spline_window`.

    Returns:
        numpy array of shape (window_size, window_size, 1). The same array
        object is returned for repeated calls with the same arguments, so
        callers must not modify it in place.
    """
    # Memoization
    key = "{}_{}".format(window_size, power)
    if key in cached_2d_windows:
        return cached_2d_windows[key]

    wind = _spline_window(window_size, power)
    # (n,) -> (n, 1, 1); multiplying by the (1, n, 1) transpose gives (n, n, 1).
    wind = np.expand_dims(np.expand_dims(wind, -1), -1)
    wind = wind * wind.transpose(1, 0, 2)

    # PLOT_PROGRESS is a module-level demo switch; treat it as off when it is
    # not defined so that importing this module never fails here.
    if globals().get("PLOT_PROGRESS", False):
        # Imported lazily: matplotlib is only needed for the demo plots.
        import matplotlib.pyplot as plt
        # For demo purpose, let's look once at the window:
        plt.imshow(wind[:, :, 0], cmap="viridis")
        plt.title("2D Windowing Function for a Smooth Blending of "
                  "Overlapping Patches")
        plt.show()

    cached_2d_windows[key] = wind
    return wind

def _pad_img(img, window_size, subdivisions):
    """
    Add borders to img for a "valid" border pattern according to "window_size"
    and "subdivisions".

    Image is an np array of shape (y, x, nb_channels).

    Args:
        img: 3D numpy array; only the first two axes are padded.
        window_size: side length of the prediction window.
        subdivisions: overlap factor of the tiling.

    Returns:
        The image reflect-padded by
        ``int(round(window_size * (1 - 1.0/subdivisions)))`` pixels on each
        side of the first two axes.
    """
    aug = int(round(window_size * (1 - 1.0/subdivisions)))
    more_borders = ((aug, aug), (aug, aug), (0, 0))
    ret = np.pad(img, pad_width=more_borders, mode='reflect')
    # gc.collect()

    # PLOT_PROGRESS is a module-level demo switch; treat it as off when it is
    # not defined so that library use never fails here.
    if globals().get("PLOT_PROGRESS", False):
        # Imported lazily: matplotlib is only needed for the demo plots.
        import matplotlib.pyplot as plt
        # For demo purpose, let's look once at the padded image:
        plt.imshow(ret)
        plt.title("Padded Image for Using Tiled Prediction Patches\n"
                  "(notice the reflection effect on the padded borders)")
        plt.show()
    return ret

def _unpad_img(padded_img, window_size, subdivisions):
    """
    Undo what's done in the `_pad_img` function.

    Image is an np array of shape (y, x, nb_channels).

    Args:
        padded_img: 3D numpy array as produced by `_pad_img`.
        window_size: side length of the prediction window used for padding.
        subdivisions: overlap factor used for padding.

    Returns:
        The array with the border removed from the first two axes.
    """
    aug = int(round(window_size * (1 - 1.0/subdivisions)))
    # Explicit end indices instead of `aug:-aug`: when aug == 0
    # (subdivisions == 1) the slice `0:-0` would be empty.
    ret = padded_img[
        aug:padded_img.shape[0] - aug,
        aug:padded_img.shape[1] - aug,
        :
    ]
    # gc.collect()
    return ret

def _rotate_mirror_do(im):
    """
    Duplicate an np array (image) of shape (y, x, nb_channels) 8 times, in order
    to have all the possible rotations and mirrors of that image that fits the
    possible 90 degrees rotations.

    It is the D_4 (D4) Dihedral group.

    Returns:
        List of 8 arrays, in the order expected by `_rotate_mirror_undo`:
        indices 0-3 are the image rotated by 0, 90, 180 and 270 degrees,
        indices 4-7 are the image mirrored along its second axis and then
        rotated by 0, 90, 180 and 270 degrees.
    """
    # Copy once so the returned views never alias the caller's array.
    im = np.array(im)
    flipped = im[:, ::-1]
    return [
        np.rot90(source, axes=(0, 1), k=k)
        for source in (im, flipped)
        for k in range(4)
    ]

def _rotate_mirror_undo(im_mirrs):
    """
    Merge a list of 8 np arrays (images) of shape (y, x, nb_channels) generated
    from the `_rotate_mirror_do` function. Each image might have changed, and
    merging them implies rotating them back in order and averaging things out.

    It is the D_4 (D4) Dihedral group.

    Args:
        im_mirrs: sequence of 8 arrays; indices 0-3 are rotations by
            0/90/180/270 degrees, indices 4-7 are the same rotations of the
            image mirrored along its second axis.

    Returns:
        The element-wise mean of the 8 arrays after each has been brought
        back to the original orientation.
    """
    origs = []
    # Index 0 is the untouched image and must take part in the average too.
    origs.append(np.array(im_mirrs[0]))
    # Rotations are undone by completing the full turn (k + k' == 4).
    origs.append(np.rot90(np.array(im_mirrs[1]), axes=(0, 1), k=3))
    origs.append(np.rot90(np.array(im_mirrs[2]), axes=(0, 1), k=2))
    origs.append(np.rot90(np.array(im_mirrs[3]), axes=(0, 1), k=1))
    # Mirrored variants: rotate back first, then undo the mirror.
    origs.append(np.array(im_mirrs[4])[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[5]), axes=(0, 1), k=3)[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[6]), axes=(0, 1), k=2)[:, ::-1])
    origs.append(np.rot90(np.array(im_mirrs[7]), axes=(0, 1), k=1)[:, ::-1])
    return np.mean(origs, axis=0)

def _windowed_subdivs(padded_img, window_size, subdivisions, n_classes, pred_func):
    """
    Create tiled overlapping patches, predict on them and weight the result.

    Args:
        padded_img: 3D numpy array (first two axes are tiled).
        window_size: side length of the square patches.
        subdivisions: overlap factor; the stride is
            ``int(window_size/subdivisions)``.
        n_classes: size of the last axis of the prediction.
        pred_func: callable receiving a 4D batch of patches and returning the
            batch of predictions. If it returns a list, the last element of
            that list is used.

    Returns:
        5D numpy array of shape = (
            nb_patches_along_X,
            nb_patches_along_Y,
            patches_resolution_along_X,
            patches_resolution_along_Y,
            n_classes
        )

    Raises:
        ValueError: if the image is smaller than the window, so that no patch
            can be extracted.

    Note:
        patches_resolution_along_X == patches_resolution_along_Y == window_size
    """
    WINDOW_SPLINE_2D = _window_2D(window_size=window_size, power=2)

    step = int(window_size/subdivisions)
    padx_len = padded_img.shape[0]
    pady_len = padded_img.shape[1]
    if padx_len < window_size or pady_len < window_size:
        raise ValueError(
            "padded image of shape {} is smaller than window_size={}".format(
                padded_img.shape, window_size))

    # One inner list of patches per position along the first axis.
    subdivs = [
        [padded_img[i:i+window_size, j:j+window_size, :]
         for j in range(0, pady_len-window_size+1, step)]
        for i in range(0, padx_len-window_size+1, step)
    ]

    # Here, `gc.collect()` clears RAM between operations.
    # It should run faster if they are removed, if enough memory is available.
    gc.collect()
    subdivs = np.array(subdivs)
    a, b, c, d, e = subdivs.shape
    # Flatten the patch grid into one batch for the prediction function.
    subdivs = subdivs.reshape(a * b, c, d, e)

    subdivs = pred_func(subdivs)
    if isinstance(subdivs, list):
        subdivs = subdivs[-1]
    gc.collect()
    # The (window, window, 1) window broadcasts over the batch and channels.
    subdivs = np.asarray(subdivs) * WINDOW_SPLINE_2D
    gc.collect()

    # Such 5D array:
    subdivs = subdivs.reshape(a, b, c, d, n_classes)

    return subdivs

def create_gen(subdivs, subdivs_m, subdivs_w):
    """
    Yield ``([img, weights], label)`` pairs from three parallel iterables.

    Args:
        subdivs: iterable providing the `img` items.
        subdivs_m: iterable providing the `label` items.
        subdivs_w: iterable providing the `weights` items.

    Yields:
        Tuples ``([img, weights], label)``; iteration stops as soon as the
        shortest of the three iterables is exhausted.
    """
    for img, label, weights in zip(subdivs, subdivs_m, subdivs_w):
        yield ([img, weights], label)
def _recreate_from_subdivs(subdivs, window_size, subdivisions, padded_out_shape):
    """
    Merge tiled overlapping patches smoothly.

    Args:
        subdivs: array indexed as ``subdivs[a, b]`` giving the (already
            windowed) patch at grid position (a, b).
        window_size: side length of the square patches.
        subdivisions: overlap factor; the stride is
            ``int(window_size/subdivisions)``.
        padded_out_shape: shape of the output array to rebuild.

    Returns:
        Array of shape `padded_out_shape`: the sum of all patches placed at
        their grid position, divided by ``subdivisions ** 2``.
    """
    step = int(window_size/subdivisions)
    padx_len = padded_out_shape[0]
    pady_len = padded_out_shape[1]

    y = np.zeros(padded_out_shape)

    # (a, b) index the patch grid, (i, j) the top-left pixel of each patch.
    for a, i in enumerate(range(0, padx_len-window_size+1, step)):
        for b, j in enumerate(range(0, pady_len-window_size+1, step)):
            y[i:i+window_size, j:j+window_size] += subdivs[a, b]
    return y / (subdivisions ** 2)
def predict_img_with_smooth_windowing(input_img, window_size, subdivisions, n_classes, pred_func):
    """
    Apply the `pred_func` function to square patches of the image, and overlap
    the predictions to merge them smoothly.

    The image is padded, duplicated into its 8 rotations/mirrors, predicted
    patch by patch for each of them, and the 8 results are averaged back.

    Args:
        input_img: numpy array of shape (y, x, nb_channels).
        window_size: side length of the square patches given to `pred_func`.
        subdivisions: overlap factor of the tiling.
        n_classes: size of the last axis of the predictions.
        pred_func: callable taking a batch of patches and returning the batch
            of predictions.

    Returns:
        Prediction array cropped to the first two dimensions of `input_img`.
    """
    pad = _pad_img(input_img, window_size, subdivisions)
    pads = _rotate_mirror_do(pad)

    # Note that the implementation could be more memory-efficient by merging
    # the behavior of `_windowed_subdivs` and `_recreate_from_subdivs` into
    # one loop doing in-place assignments to the new image matrix, rather than
    # using a temporary 5D array.

    # It would also be possible to allow different (and impure) window functions
    # that might not tile well. Adding their weighting to another matrix could
    # be done to later normalize the predictions correctly by dividing the whole
    # reconstructed thing by this matrix of weightings - to normalize things
    # back from an impure windowing function that would have badly weighted
    # windows.

    # For example, since the U-net of Kaggle's DSTL satellite imagery feature
    # prediction challenge's 3rd place winners use a different window size for
    # the input and output of the neural net's patches predictions, it would be
    # possible to fake a full-size window which would in fact just have a narrow
    # non-zero domain. This may require to augment the `subdivisions` argument
    # to 4 rather than 2.

    res = []
    for rotated_pad in tqdm(pads):
        # For every rotation:
        sd = _windowed_subdivs(rotated_pad, window_size, subdivisions, n_classes, pred_func)
        one_padded_result = _recreate_from_subdivs(
            sd, window_size, subdivisions,
            padded_out_shape=list(rotated_pad.shape[:-1])+[n_classes])
        res.append(one_padded_result)

    # Merge after rotations:
    padded_results = _rotate_mirror_undo(res)

    prd = _unpad_img(padded_results, window_size, subdivisions)
    prd = prd[:input_img.shape[0], :input_img.shape[1], :]

    # PLOT_PROGRESS is a module-level demo switch; treat it as off when it is
    # not defined so that library use never fails here.
    if globals().get("PLOT_PROGRESS", False):
        # Imported lazily: matplotlib is only needed for the demo plots.
        import matplotlib.pyplot as plt
        plt.imshow(prd)
        plt.title("Smoothly Merged Patches that were Tiled Tighter")
        plt.show()
    return prd
def predict_img_with_overlap(input_img, window_size, subdivisions, n_classes, pred_func):
    """
    Based on predict_img_with_smooth_windowing but works just with the
    original image instead of creating 8 new ones.

    Args:
        input_img: numpy array of shape (y, x, nb_channels).
        window_size: side length of the square patches given to `pred_func`.
        subdivisions: overlap factor of the tiling.
        n_classes: size of the last axis of the predictions.
        pred_func: callable taking a batch of patches and returning the batch
            of predictions.

    Returns:
        Prediction array cropped to the first two dimensions of `input_img`.
    """
    pad = _pad_img(input_img, window_size, subdivisions)
    sd = _windowed_subdivs(pad, window_size, subdivisions, n_classes, pred_func)
    one_padded_result = _recreate_from_subdivs(
        sd, window_size, subdivisions,
        padded_out_shape=list(pad.shape[:-1])+[n_classes])

    prd = _unpad_img(one_padded_result, window_size, subdivisions)
    prd = prd[:input_img.shape[0], :input_img.shape[1], :]

    # PLOT_PROGRESS is a module-level demo switch; treat it as off when it is
    # not defined so that library use never fails here.
    if globals().get("PLOT_PROGRESS", False):
        # Imported lazily: matplotlib is only needed for the demo plots.
        import matplotlib.pyplot as plt
        plt.imshow(prd)
        plt.title("Smoothly Merged Patches that were Tiled Tighter")
        plt.show()
    return prd