Custom network

This tutorial shows how to define a new network and add it to DeepReg, using the example of a GlobalNet that predicts an affine transformation rather than a nonrigid one.

For general guidance on making a contribution to DeepReg, see the contribution guidelines.

Step 1: Create network backbone

The first step is to create a new backbone class, which defines the neural network architecture you want to use, and place it in the backbone directory deepreg/model/backbone/. The affine method uses the GlobalNet architecture (deepreg/model/backbone/global_net.py) from Hu et al. 2018. GlobalNet is an encoder-only network that predicts the 12 parameters (degrees of freedom) of a 3D affine transformation.
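
To make the 12 degrees of freedom concrete, the sketch below applies a 4x3 affine parameter matrix to a grid point in plain Python. It assumes the row-vector convention suggested by the theta shape [batch, 4, 3] used in the backbone code, i.e. a point [x, y, z, 1] multiplied by theta. The helper names warp_point and displacement are illustrative and are not part of DeepReg.

```python
# Illustrative only: plain-Python affine warp of a 3D point.
# Assumed convention: warped = [x, y, z, 1] @ theta, with theta of shape (4, 3).


def warp_point(point, theta):
    """Apply a 4x3 affine matrix to one 3D point."""
    homogeneous = list(point) + [1.0]  # the appended 1 picks up the translation row
    return [
        sum(homogeneous[q] * theta[q][p] for q in range(4))
        for p in range(3)
    ]


def displacement(point, theta):
    """Displacement is the warped position minus the original position."""
    warped = warp_point(point, theta)
    return [w - x for w, x in zip(warped, point)]


# Identity transformation: 3x3 identity on top, zero translation in the last row.
identity = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 0.0],
]

# Same as identity, but shifted by 2 voxels along the first axis.
shift = [row[:] for row in identity]
shift[3] = [2.0, 0.0, 0.0]

print(displacement((1.0, 2.0, 3.0), identity))  # [0.0, 0.0, 0.0]
print(displacement((1.0, 2.0, 3.0), shift))  # [2.0, 0.0, 0.0]
```

Evaluating the displacement at every voxel of a reference grid gives a dense displacement field, which is what the backbone returns.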

We recommend using the tf.keras API to write your network: define the layers of your backbone class in __init__() and the network’s forward pass in call(). Custom DeepReg layers can be found in deepreg/model/layer.py.

class GlobalNet(tf.keras.Model):
    """
    Builds GlobalNet for image registration based on
    Y. Hu et al.,
    "Label-driven weakly-supervised learning for multimodal
    deformable image registration,"
    (ISBI 2018), pp. 1070-1074.
    """

    def __init__(
        self,
        image_size,
        out_channels,
        num_channel_initial,
        extract_levels,
        out_kernel_initializer,
        out_activation,
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E.
        Then, a densely-connected layer outputs an affine
        transformation.
        :param image_size: tuple, dims of image, (dim1, dim2, dim3)
        :param out_channels: int, number of channels for the output
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract
        :param out_activation: str, activation at last layer
        :param out_kernel_initializer: str, which kernel to use as initialiser
        :param kwargs:
        """
        super(GlobalNet, self).__init__(**kwargs)
        # save parameters
        self._extract_levels = extract_levels
        self._extract_max_level = max(self._extract_levels)  # E
        self.reference_grid = layer_util.get_reference_grid(image_size)
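        # identity affine transformation flattened from shape (4, 3):
        # a 3x3 identity followed by a zero translation row, assuming
        # warp_grid computes [x, y, z, 1] @ theta
        self.transform_initial = tf.constant_initializer(
            value=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
        )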
        # init layer variables
        num_channels = [
            num_channel_initial * (2 ** level)
            for level in range(self._extract_max_level + 1)
        ]  # level 0 to E
        self._downsample_blocks = [
            layer.DownSampleResnetBlock(
                filters=num_channels[i], kernel_size=7 if i == 0 else 3
            )
            for i in range(self._extract_max_level)
        ]  # level 0 to E-1
        self._conv3d_block = layer.Conv3dBlock(filters=num_channels[-1])  # level E
        self._dense_layer = layer.Dense(
            units=12, bias_initializer=self.transform_initial
        )

    def call(self, inputs, training=None, mask=None):
        """
        Build GlobalNet graph based on built layers.
        :param inputs: image batch, shape = [batch, f_dim1, f_dim2, f_dim3, ch]
        :param training:
        :param mask:
        :return:
        """
        # down sample from level 0 to E
        h_in = inputs
        for level in range(self._extract_max_level):  # level 0 to E - 1
            h_in, _ = self._downsample_blocks[level](inputs=h_in, training=training)
        h_out = self._conv3d_block(
            inputs=h_in, training=training
        )  # level E of encoding
        # predict affine parameters theta of shape = [batch, 4, 3]
        self.theta = self._dense_layer(h_out)
        self.theta = tf.reshape(self.theta, shape=(-1, 4, 3))
        # warp the reference grid with affine parameters to output a ddf
        grid_warped = layer_util.warp_grid(self.reference_grid, self.theta)
        output = grid_warped - self.reference_grid
        return output
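
The layer bookkeeping in __init__ can be checked without TensorFlow. The sketch below reproduces only the arithmetic from the constructor: channel counts double at every level, the first down-sampling block uses a kernel size of 7 and the remaining ones use 3. The function name encoder_plan is hypothetical and not part of DeepReg.

```python
def encoder_plan(num_channel_initial, extract_levels):
    """Return (level, filters, kernel_size) for each encoder block.

    kernel_size is None for the final Conv3dBlock, because the constructor
    above does not set its kernel size explicitly.
    """
    max_level = max(extract_levels)  # E
    num_channels = [
        num_channel_initial * (2 ** level) for level in range(max_level + 1)
    ]  # level 0 to E
    plan = [
        (level, num_channels[level], 7 if level == 0 else 3)
        for level in range(max_level)
    ]  # down-sampling blocks, level 0 to E-1
    plan.append((max_level, num_channels[-1], None))  # level E
    return plan


for level, filters, kernel_size in encoder_plan(4, [1, 2, 3]):
    print(level, filters, kernel_size)
```

With num_channel_initial=4 and extract_levels=[1, 2, 3], as in the unit tests of Step 3, this gives three down-sampling blocks with 4, 8 and 16 filters and a final block with 32 filters.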

To use the backbone network in the DeepReg pipeline, a new option needs to be added to build_backbone() in deepreg/model/network/util.py. We use the keyword “global” to refer to our GlobalNet class and “affine” for the method name. This allows us to select the backbone network directly in the configuration file.

def build_backbone(
    image_size: tuple, out_channels: int, model_config: dict, method_name: str
) -> tf.keras.Model:
    """
    Backbone model accepts a single input of shape (batch, dim1, dim2, dim3, ch_in)
    and returns a single output of shape (batch, dim1, dim2, dim3, ch_out)
    :param image_size: tuple, dims of image, (dim1, dim2, dim3)
    :param out_channels: int, number of out channels, ch_out
    :param method_name: str, one of ddf | dvf | conditional | affine
    :param model_config: dict, model configuration, returned from parser.yaml.load
    :return: tf.keras.Model
    """
    if not (
        (isinstance(image_size, tuple) or isinstance(image_size, list))
        and len(image_size) == 3
    ):
        raise ValueError(f"image_size must be tuple of length 3, got {image_size}")
    if not (isinstance(out_channels, int) and out_channels >= 1):
        raise ValueError(f"out_channels must be int >=1, got {out_channels}")
    if not (isinstance(model_config, dict) and "backbone" in model_config.keys()):
        raise ValueError(
            f"model_config must be a dict having key 'backbone', got {model_config}"
        )
    if method_name not in ["ddf", "dvf", "conditional", "affine"]:
        raise ValueError(
            "method name has to be one of ddf/dvf/conditional/affine in build_backbone, "
            "got {}".format(method_name)
        )

    if method_name in ["ddf", "dvf"]:
        out_activation = None
        # TODO try random init with smaller number
        out_kernel_initializer = "zeros"  # to ensure small ddf and dvf
    elif method_name in ["conditional"]:
        out_activation = "sigmoid"  # output is probability
        out_kernel_initializer = "glorot_uniform"
    elif method_name in ["affine"]:
        out_activation = None
        out_kernel_initializer = "zeros"
    else:
        raise ValueError("Unknown method name {}".format(method_name))

    if model_config["backbone"] == "local":
        return LocalNet(
            image_size=image_size,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            **model_config["local"],
        )
    elif model_config["backbone"] == "global":
        return GlobalNet(
            image_size=image_size,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            **model_config["global"],
        )
    elif model_config["backbone"] == "unet":
        return UNet(
            image_size=image_size,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            **model_config["unet"],
        )
    else:
        raise ValueError("Unknown model name")
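
Note that `**model_config["global"]` forwards every key of the "global" section as a keyword argument, so the keys in the configuration must match the constructor's parameter names. The stand-in class below (FakeGlobalNet, which is not a DeepReg class and builds no layers) illustrates this behaviour in isolation.

```python
class FakeGlobalNet:
    """Stand-in with the same constructor signature as GlobalNet above."""

    def __init__(
        self,
        image_size,
        out_channels,
        num_channel_initial,
        extract_levels,
        out_kernel_initializer,
        out_activation,
        **kwargs,
    ):
        self.num_channel_initial = num_channel_initial
        self.extract_levels = extract_levels


model_config = {
    "backbone": "global",
    "global": {"num_channel_initial": 4, "extract_levels": [1, 2, 3]},
}

net = FakeGlobalNet(
    image_size=(2, 4, 6),
    out_channels=3,
    out_kernel_initializer="zeros",
    out_activation=None,
    **model_config["global"],  # num_channel_initial=4, extract_levels=[1, 2, 3]
)
print(net.num_channel_initial, net.extract_levels)

# A key missing from the section surfaces as a TypeError at construction time.
try:
    FakeGlobalNet(
        image_size=(2, 4, 6),
        out_channels=3,
        out_kernel_initializer="zeros",
        out_activation=None,
        **{"num_channel_initial": 4},  # "extract_levels" is missing
    )
except TypeError as err:
    missing_key_error = str(err)
    print(missing_key_error)
```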

Step 2: Create network model

We can now create a network model for the affine method in deepreg/model/network/affine.py. We first write the model’s forward pass, which uses the backbone network to predict an affine transformation; this transformation is then used to output a dense displacement field (DDF).

def affine_forward(
    backbone: tf.keras.Model,
    moving_image: tf.Tensor,
    fixed_image: tf.Tensor,
    moving_label: (tf.Tensor, None),
    moving_image_size: tuple,
    fixed_image_size: tuple,
):
    """
    Perform the network forward pass
    :param backbone: model architecture object, e.g. model.backbone.local_net
    :param moving_image: tensor of shape (batch, m_dim1, m_dim2, m_dim3)
    :param fixed_image:  tensor of shape (batch, f_dim1, f_dim2, f_dim3)
    :param moving_label: tensor of shape (batch, m_dim1, m_dim2, m_dim3) or None
    :param moving_image_size: tuple like (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: tuple like (f_dim1, f_dim2, f_dim3)
    :return: tuple(affine, ddf, pred_fixed_image, pred_fixed_label, fixed_grid), where
    - affine is the affine transformation matrix predicted by the network (batch, 4, 3)
    - ddf is the dense displacement field of shape (batch, f_dim1, f_dim2, f_dim3, 3)
    - pred_fixed_image is the predicted (warped) moving image of shape (batch, f_dim1, f_dim2, f_dim3)
    - pred_fixed_label is the predicted (warped) moving label of shape (batch, f_dim1, f_dim2, f_dim3)
    - fixed_grid is the grid of shape (f_dim1, f_dim2, f_dim3, 3)
    """

    # expand dims
    # need to be squeezed later for warping
    moving_image = tf.expand_dims(
        moving_image, axis=4
    )  # (batch, m_dim1, m_dim2, m_dim3, 1)
    fixed_image = tf.expand_dims(
        fixed_image, axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # adjust moving image
    moving_image = layer_util.resize3d(
        image=moving_image, size=fixed_image_size
    )  # (batch, f_dim1, f_dim2, f_dim3, 1)

    # ddf, affine
    inputs = tf.concat(
        [moving_image, fixed_image], axis=4
    )  # (batch, f_dim1, f_dim2, f_dim3, 2)
    ddf = backbone(inputs=inputs)  # (batch, f_dim1, f_dim2, f_dim3, 3)
    affine = backbone.theta

    # prediction, (batch, f_dim1, f_dim2, f_dim3)
    warping = layer.Warping(fixed_image_size=fixed_image_size)
    grid_fixed = tf.squeeze(warping.grid_ref, axis=0)  # (f_dim1, f_dim2, f_dim3, 3)
    pred_fixed_image = warping(inputs=[ddf, tf.squeeze(moving_image, axis=4)])
    pred_fixed_label = (
        warping(inputs=[ddf, moving_label]) if moving_label is not None else None
    )
    return affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed
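
The shape comments in affine_forward can be followed with plain tuples. The sketch below only tracks shapes, using the image sizes from the unit test in Step 3. The function trace_shapes is a hypothetical helper and performs no actual resampling or warping.

```python
def trace_shapes(batch, moving_image_size, fixed_image_size):
    """Follow tensor shapes through the steps of affine_forward."""
    shapes = {}
    moving = (batch,) + tuple(moving_image_size) + (1,)  # after expand_dims
    fixed = (batch,) + tuple(fixed_image_size) + (1,)  # after expand_dims
    shapes["moving_expanded"] = moving
    moving = (batch,) + tuple(fixed_image_size) + (1,)  # after resize3d
    shapes["moving_resized"] = moving
    # concatenation along the channel axis adds the channel counts
    shapes["backbone_input"] = fixed[:-1] + (moving[-1] + fixed[-1],)
    shapes["ddf"] = fixed[:-1] + (3,)  # one displacement vector per voxel
    shapes["affine"] = (batch, 4, 3)
    shapes["pred_fixed_image"] = fixed[:-1]  # channel axis squeezed before warping
    return shapes


for name, shape in trace_shapes(1, (1, 3, 5), (2, 4, 6)).items():
    print(name, shape)
```

The moving image is resized to the fixed image size before concatenation, which is why every output is defined on the fixed image grid.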

Similar to build_backbone, we also need to write the build_affine_model function, which builds the model from the network’s inputs, backbone and loss functions.

def build_affine_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    :param moving_image_size: (m_dim1, m_dim2, m_dim3)
    :param fixed_image_size: (f_dim1, f_dim2, f_dim3)
    :param index_size: int, the number of indices for identifying a sample
    :param labeled: bool, indicating if the data is labeled
    :param batch_size: int, size of mini-batch
    :param model_config: config for the model
    :param loss_config: config for the loss
    :return: the built tf.keras.Model
    """

    # inputs
    (moving_image, fixed_image, moving_label, fixed_label, indices) = build_inputs(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=index_size,
        batch_size=batch_size,
        labeled=labeled,
    )

    # backbone
    backbone = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        model_config=model_config,
        method_name=model_config["method"],
    )

    # forward
    affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = affine_forward(
        backbone=backbone,
        moving_image=moving_image,
        fixed_image=fixed_image,
        moving_label=moving_label,
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )

    # build model
    inputs = {
        "moving_image": moving_image,
        "fixed_image": fixed_image,
        "indices": indices,
    }
    outputs = {"ddf": ddf, "affine": affine}
    model_name = model_config["method"].upper() + "RegistrationModel"
    if moving_label is None:  # unlabeled
        model = tf.keras.Model(
            inputs=inputs, outputs=outputs, name=model_name + "WithoutLabel"
        )
    else:  # labeled
        inputs["moving_label"] = moving_label
        inputs["fixed_label"] = fixed_label
        outputs["pred_fixed_label"] = pred_fixed_label
        model = tf.keras.Model(
            inputs=inputs, outputs=outputs, name=model_name + "WithLabel"
        )

    # add loss and metric
    model = add_ddf_loss(model=model, ddf=ddf, loss_config=loss_config)
    model = add_image_loss(
        model=model,
        fixed_image=fixed_image,
        pred_fixed_image=pred_fixed_image,
        loss_config=loss_config,
    )
    model = add_label_loss(
        model=model,
        grid_fixed=grid_fixed,
        fixed_label=fixed_label,
        pred_fixed_label=pred_fixed_label,
        loss_config=loss_config,
    )

    return model
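
How the model name and its input and output dictionaries depend on whether labels are available can be summarised in a few lines. The sketch below mirrors only that bookkeeping from build_affine_model; describe_model is a hypothetical helper and builds no Keras model.

```python
def describe_model(method, labeled):
    """Return the model name plus input and output keys for a given setup."""
    inputs = ["moving_image", "fixed_image", "indices"]
    outputs = ["ddf", "affine"]
    name = method.upper() + "RegistrationModel"
    if labeled:
        # labels add two inputs and one output
        inputs += ["moving_label", "fixed_label"]
        outputs.append("pred_fixed_label")
        name += "WithLabel"
    else:
        name += "WithoutLabel"
    return name, inputs, outputs


print(describe_model("affine", labeled=True))
print(describe_model("affine", labeled=False))
```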

The last step is to add the build_affine_model option to build_model() in deepreg/model/network/build.py so that it can be selected from the configuration file.

def build_model(
    moving_image_size: tuple,
    fixed_image_size: tuple,
    index_size: int,
    labeled: bool,
    batch_size: int,
    model_config: dict,
    loss_config: dict,
):
    """
    Parsing algorithm types to model building functions
    :param moving_image_size: [m_dim1, m_dim2, m_dim3]
    :param fixed_image_size: [f_dim1, f_dim2, f_dim3]
    :param index_size: dataset size
    :param labeled: true if the labels of the moving/fixed images are provided
    :param batch_size: mini-batch size
    :param model_config: model configuration, e.g. dictionary return from parser.yaml.load
    :param loss_config: loss configuration, e.g. dictionary return from parser.yaml.load
    :return: the built tf.keras.Model
    """
    if model_config["method"] in ["ddf", "dvf"]:
        return build_ddf_dvf_model(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            model_config=model_config,
            loss_config=loss_config,
        )
    elif model_config["method"] == "conditional":
        return build_conditional_model(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            model_config=model_config,
            loss_config=loss_config,
        )
    elif model_config["method"] == "affine":
        return build_affine_model(
            moving_image_size=moving_image_size,
            fixed_image_size=fixed_image_size,
            index_size=index_size,
            labeled=labeled,
            batch_size=batch_size,
            model_config=model_config,
            loss_config=loss_config,
        )
    else:
        raise ValueError("Unknown model method")

Step 3: Testing (for contributing developers, optional for users)

Everyone is warmly welcome to contribute to DeepReg and to add corresponding unit tests for newly added functions to test/unit/. Recommendations on testing style can be found in the contribution guidelines. Here is a practical example of the unit tests written for our affine model in test/unit/test_affine.py:

def test_affine_forward():
    """
    Testing that affine_forward function returns the tensors with correct shapes
    """

    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    global_net = build_backbone(
        image_size=fixed_image_size,
        out_channels=3,
        model_config={
            "backbone": "global",
            "global": {"num_channel_initial": 4, "extract_levels": [1, 2, 3]},
        },
        method_name="affine",
    )

    # Check affine mode network output shapes - Pass
    affine, ddf, pred_fixed_image, pred_fixed_label, grid_fixed = affine_forward(
        backbone=global_net,
        moving_image=tf.ones((batch_size,) + moving_image_size),
        fixed_image=tf.ones((batch_size,) + fixed_image_size),
        moving_label=tf.ones((batch_size,) + moving_image_size),
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
    )
    assert affine.shape == (batch_size,) + (4,) + (3,)
    assert ddf.shape == (batch_size,) + fixed_image_size + (3,)
    assert pred_fixed_image.shape == (batch_size,) + fixed_image_size
    assert pred_fixed_label.shape == (batch_size,) + fixed_image_size
    assert grid_fixed.shape == fixed_image_size + (3,)


def test_build_affine_model():
    """
    Testing that build_affine_model function returns the tensors with correct shapes
    """
    moving_image_size = (1, 3, 5)
    fixed_image_size = (2, 4, 6)
    batch_size = 1

    model = build_affine_model(
        moving_image_size=moving_image_size,
        fixed_image_size=fixed_image_size,
        index_size=1,
        labeled=True,
        batch_size=batch_size,
        model_config={
            "method": "affine",
            "backbone": "global",
            "global": {"num_channel_initial": 4, "extract_levels": [1, 2, 3]},
        },
        loss_config={
            "dissimilarity": {
                "image": {"name": "lncc", "weight": 0.1},
                "label": {
                    "name": "multi_scale",
                    "weight": 1,
                    "multi_scale": {
                        "loss_type": "dice",
                        "loss_scales": [0, 1, 2, 4, 8, 16, 32],
                    },
                },
            },
            "regularization": {"weight": 0.0, "energy_type": "bending"},
        },
    )

    inputs = {
        "moving_image": tf.ones((batch_size,) + moving_image_size),
        "fixed_image": tf.ones((batch_size,) + fixed_image_size),
        "indices": 1,
        "moving_label": tf.ones((batch_size,) + moving_image_size),
        "fixed_label": tf.ones((batch_size,) + fixed_image_size),
    }

    outputs = model(inputs)

    expected_outputs_keys = ["affine", "ddf", "pred_fixed_label"]
    assert all(keys in expected_outputs_keys for keys in outputs)
    assert outputs["pred_fixed_label"].shape == (batch_size,) + fixed_image_size
    assert outputs["affine"].shape == (batch_size,) + (4,) + (3,)
    assert outputs["ddf"].shape == (batch_size,) + fixed_image_size + (3,)
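
One detail of the key check above is worth noting: it verifies that every returned key is expected, but it would still pass if an expected key were missing. A stricter variant compares the two sets of keys directly. The sketch below uses a stand-in outputs dictionary rather than a real model.

```python
expected_outputs_keys = ["affine", "ddf", "pred_fixed_label"]
outputs = {"affine": None, "ddf": None}  # stand-in: "pred_fixed_label" is missing

# Subset check, as in the test above: passes despite the missing key.
subset_check = all(key in expected_outputs_keys for key in outputs)

# Exact check: fails because the key sets differ.
exact_check = set(outputs) == set(expected_outputs_keys)

print(subset_check, exact_check)  # True False
```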

Step 4: Set yaml configuration files

An example yaml configuration file for the affine method is available in config/unpaired_labeled_affine.yaml. To use both the GlobalNet backbone and the affine method, you need to specify the aforementioned keywords “global” and “affine”. Optional parameters such as out_kernel_initializer or num_channel_initial can also be specified. A snippet of config/unpaired_labeled_affine.yaml is shown below. Please see the configuration documentation for more details.

model:
  method: "affine"
  backbone:
    name: "global"
    out_kernel_initializer: "zeros"
    out_activation: ""
  global:
    num_channel_initial: 1
    extract_levels: [0, 1, 2, 3, 4]
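
Once parsed, the snippet above corresponds to a nested Python dictionary. The sketch below writes that dictionary by hand (no YAML library is used) and assumes that method, backbone and global are all nested under model. The sanity check check_model_config is a hypothetical helper, not a DeepReg function.

```python
config = {
    "model": {
        "method": "affine",
        "backbone": {
            "name": "global",
            "out_kernel_initializer": "zeros",
            "out_activation": "",
        },
        "global": {"num_channel_initial": 1, "extract_levels": [0, 1, 2, 3, 4]},
    }
}


def check_model_config(model_config):
    """Hypothetical check: the named backbone must have its own section."""
    if model_config["method"] not in ["ddf", "dvf", "conditional", "affine"]:
        raise ValueError(f"Unknown method {model_config['method']}")
    backbone_name = model_config["backbone"]["name"]
    if backbone_name not in model_config:
        raise ValueError(f"Missing section for backbone '{backbone_name}'")
    return backbone_name, model_config[backbone_name]


print(check_model_config(config["model"]))
```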