
Introduction

Before we begin, we first need to be clear about one thing: what is image segmentation?

We know that an image is nothing more than a collection of pixels. Image segmentation is the process of classifying the pixels of an image into particular categories, i.e., a pixel-level downstream task. Simply put, image segmentation is a per-pixel classification problem.
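
To make "per-pixel classification" concrete, here is a minimal sketch (PyTorch; the shapes are hypothetical): a segmentation network produces one score per class per pixel, and taking the argmax over the class dimension yields a label map.

import torch

# Hypothetical shapes: a batch of 2 images, 21 classes (e.g. PASCAL VOC), 224x224.
# A segmentation network outputs logits of shape (N, C, H, W):
logits = torch.randn(2, 21, 224, 224)

# Per-pixel classification: pick the highest-scoring class at every pixel.
pred = logits.argmax(dim=1)    # shape (2, 224, 224), one integer label per pixel
print(pred.shape)              # torch.Size([2, 224, 224])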

Traditional image segmentation algorithms are all based on the discontinuity and similarity of gray-level values. Deep-learning-based image segmentation, by contrast, uses convolutional neural networks to understand which real-world object each pixel represents, something that was previously hard to imagine.


Semantic Segmentation

**Definition**

"Semantics" is an abstract concept. In the 2D image domain the pixel is the smallest unit, and its value carries a feature, i.e., its "semantic" information. Semantic segmentation assigns a category to every pixel in the image, but does not distinguish between different objects of the same category. Instance segmentation, by contrast, classifies only specific (countable) objects. This looks similar to object detection, except that object detection outputs each target's bounding box and class, whereas instance segmentation outputs each target's mask and class. Concretely, the goal of semantic segmentation is to understand the image at the pixel level and assign an object class to every pixel.

Semantic segmentation is a technique that assigns every pixel in an image to a specific category. Its goal is to recognize the various objects and background regions present in an image and give each pixel the corresponding class label, for example dividing the pixels of an image into regions such as person, tree, grass, and sky. It is an important branch of image processing and machine vision. Unlike classification tasks, semantic segmentation must decide the category of every single pixel, producing a precise segmentation. It is now widely used in areas such as autonomous driving, automatic matting, and medical imaging.

**Characteristics**

  • Provides precise pixel-level classification, helping to understand image content in depth.
  • Cannot distinguish different instances within the same category.

**Applications of Semantic Segmentation**

Semantic segmentation is widely used across many fields:

  • Autonomous driving: recognizing roads, vehicles, and pedestrians.
  • Medical imaging: segmenting tissues and organs.
  • Satellite remote sensing: land-cover classification.

**Common Models**

FCN (Fully Convolutional Network)

  • Pros: simple and easy to use; although it is rarely used today, its historical contribution should not be overlooked.
  • Cons: relatively low segmentation accuracy; it may not handle fine details well.

Motivation

The FCN (Fully Convolutional Network) model was proposed to address the limitations of traditional convolutional neural networks (CNNs) on semantic segmentation. Specifically, a traditional CNN classifies with fully connected layers, which discards the image's spatial information and makes the network unsuitable for pixel-level prediction. FCN's core motivations include:

  1. End-to-end pixel-level prediction: by replacing the fully connected layers with convolutional layers, FCN can take input images of arbitrary size and output a pixel-level prediction of the same size as the input (a sketch of this "convolutionalization" follows the list).
  2. Preserving spatial information: with the fully connected layers removed, FCN retains the spatial layout of the image, making it better suited to semantic segmentation.
  3. Higher segmentation efficiency and accuracy: by introducing deconvolution (upsampling) layers and skip connections, FCN fuses features from different depths, combining global semantics with local detail and thereby improving accuracy.
  4. Leveraging pretrained models: FCN can be fine-tuned from pretrained classification models (such as AlexNet or VGG), which markedly speeds up training and improves performance.
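
As a sketch of point 1 above (illustrative, not code from the paper): a fully connected layer over a 7×7×512 feature map is equivalent to a 7×7 convolution with reshaped weights, and the convolutional version accepts inputs of any size.

import torch
import torch.nn as nn

# A fully connected head accepts only one fixed input size:
fc = nn.Linear(512 * 7 * 7, 4096)

# The "convolutionalized" equivalent: a 7x7 convolution with 4096 filters.
# Its weights are the FC weights reshaped, so pretrained weights carry over.
conv = nn.Conv2d(512, 4096, kernel_size=7)
conv.weight.data = fc.weight.data.view(4096, 512, 7, 7)
conv.bias.data = fc.bias.data

x = torch.randn(1, 512, 14, 14)   # a feature map larger than the FC head expects
y = conv(x)                       # works: (1, 4096, 8, 8), a coarse spatial score map
print(y.shape)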

Network Architecture

A typical CNN attaches several fully connected layers after its convolutional layers, mapping the feature maps produced by convolution into a fixed-length feature vector. Classic CNN architectures such as AlexNet are suited to image-level classification and regression, because they ultimately expect a single numerical description (a probability) of the whole input image.
FCN instead classifies the image at the pixel level, thereby solving semantic segmentation. Unlike a classic CNN, which uses fully connected layers after the convolutions to obtain a fixed-length feature vector for classification (fully connected layers + softmax), FCN accepts input images of arbitrary size and uses deconvolution layers to upsample the feature map of the last convolutional layer back to the size of the input image. This yields a prediction for every pixel while preserving the spatial information of the original input; the final per-pixel classification is then performed on the upsampled feature map.
FCN (Fully Convolutional Network) was proposed to solve semantic segmentation: it classifies the image at the pixel level while preserving the spatial information of the original input. Unlike a traditional CNN, FCN accepts input images of arbitrary size and achieves pixel-level classification as follows:

  1. Removing the fully connected layers: FCN replaces the fully connected layers of a traditional CNN with convolutional layers, preserving spatial information.
  2. Upsampling: deconvolution (upsampling) layers upsample the feature map of the last convolutional layer back to the size of the input image.
  3. Per-pixel classification: classification is performed on the upsampled feature map, producing a class prediction for every pixel.

Q: How does FCN restore the spatial resolution of the feature maps through upsampling?
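
One possible answer, as a sketch (PyTorch; the bilinear initialization below follows common FCN re-implementations rather than any single official codebase): FCN upsamples with transposed convolutions whose kernels start out as bilinear interpolation weights, so upsampling begins as plain interpolation and can then be refined by training. The output is slightly larger than 2× the input, which is why FCN implementations crop the upsampled map back to a reference size (cf. the crop() calls in the Caffe code below).

import torch
import torch.nn as nn

def bilinear_kernel(channels, kernel_size):
    # Standard bilinear-interpolation weights, one kernel per channel.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = torch.arange(kernel_size, dtype=torch.float32)
    filt = 1 - (og - center).abs() / factor
    kernel = filt[:, None] * filt[None, :]              # (k, k)
    weight = torch.zeros(channels, channels, kernel_size, kernel_size)
    for i in range(channels):
        weight[i, i] = kernel                           # channel i maps to channel i
    return weight

num_classes = 21
up = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4,
                        stride=2, bias=False)
up.weight.data = bilinear_kernel(num_classes, 4)

score = torch.randn(1, num_classes, 16, 16)   # coarse class-score map
print(up(score).shape)                        # torch.Size([1, 21, 34, 34]); then cropped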

Model Code

Paper source code (Caffe)
import caffe
from caffe import layers as L, params as P
from caffe.coord_map import crop

def conv_relu(bottom, nout, ks=3, stride=1, pad=1):
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
        num_output=nout, pad=pad,
        param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)])
    return conv, L.ReLU(conv, in_place=True)

def max_pool(bottom, ks=2, stride=2):
    return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)

def fcn(split):
    n = caffe.NetSpec()
    pydata_params = dict(split=split, mean=(104.00699, 116.66877, 122.67892),
            seed=1337)
    if split == 'train':
        pydata_params['sbdd_dir'] = '../data/sbdd/dataset'
        pylayer = 'SBDDSegDataLayer'
    else:
        pydata_params['voc_dir'] = '../data/pascal/VOC2011'
        pylayer = 'VOCSegDataLayer'
    n.data, n.label = L.Python(module='voc_layers', layer=pylayer,
            ntop=2, param_str=str(pydata_params))

    # the base net
    n.conv1_1, n.relu1_1 = conv_relu(n.data, 64, pad=100)
    n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64)
    n.pool1 = max_pool(n.relu1_2)

    n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128)
    n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128)
    n.pool2 = max_pool(n.relu2_2)

    n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256)
    n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256)
    n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256)
    n.pool3 = max_pool(n.relu3_3)

    n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512)
    n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512)
    n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512)
    n.pool4 = max_pool(n.relu4_3)

    n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512)
    n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512)
    n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512)
    n.pool5 = max_pool(n.relu5_3)

    # fully conv
    n.fc6, n.relu6 = conv_relu(n.pool5, 4096, ks=7, pad=0)
    n.drop6 = L.Dropout(n.relu6, dropout_ratio=0.5, in_place=True)
    n.fc7, n.relu7 = conv_relu(n.drop6, 4096, ks=1, pad=0)
    n.drop7 = L.Dropout(n.relu7, dropout_ratio=0.5, in_place=True)
    n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0,
        param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)])
    n.upscore2 = L.Deconvolution(n.score_fr,
        convolution_param=dict(num_output=21, kernel_size=4, stride=2,
            bias_term=False),
        param=[dict(lr_mult=0)])

    n.score_pool4 = L.Convolution(n.pool4, num_output=21, kernel_size=1, pad=0,
        param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)])
    n.score_pool4c = crop(n.score_pool4, n.upscore2)
    n.fuse_pool4 = L.Eltwise(n.upscore2, n.score_pool4c,
            operation=P.Eltwise.SUM)
    n.upscore_pool4 = L.Deconvolution(n.fuse_pool4,
        convolution_param=dict(num_output=21, kernel_size=4, stride=2,
            bias_term=False),
        param=[dict(lr_mult=0)])

    n.score_pool3 = L.Convolution(n.pool3, num_output=21, kernel_size=1, pad=0,
        param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)])
    n.score_pool3c = crop(n.score_pool3, n.upscore_pool4)
    n.fuse_pool3 = L.Eltwise(n.upscore_pool4, n.score_pool3c,
            operation=P.Eltwise.SUM)
    n.upscore8 = L.Deconvolution(n.fuse_pool3,
        convolution_param=dict(num_output=21, kernel_size=16, stride=8,
            bias_term=False),
        param=[dict(lr_mult=0)])

    n.score = crop(n.upscore8, n.data)
    n.loss = L.SoftmaxWithLoss(n.score, n.label,
            loss_param=dict(normalize=False, ignore_label=255))

    return n.to_proto()

def make_net():
    with open('train.prototxt', 'w') as f:
        f.write(str(fcn('train')))

    with open('val.prototxt', 'w') as f:
        f.write(str(fcn('seg11valid')))

if __name__ == '__main__':
    make_net()
fcn8_vgg (TensorFlow implementation)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
from math import ceil
import sys

import numpy as np
import tensorflow as tf

VGG_MEAN = [103.939, 116.779, 123.68]


class FCN8VGG:

    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            path = sys.modules[self.__class__.__module__].__file__
            # print path
            path = os.path.abspath(os.path.join(path, os.pardir))
            # print path
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            logging.info("Load npy file from '%s'.", vgg16_npy_path)
        if not os.path.isfile(vgg16_npy_path):
            logging.error(("File '%s' not found. Download it from "
                           "ftp://mi.eng.cam.ac.uk/pub/mttt2/"
                           "models/vgg16.npy"), vgg16_npy_path)
            sys.exit(1)

        self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
        self.wd = 5e-4
        print("npy file loaded")

    def build(self, rgb, train=False, num_classes=20, random_init_fc8=False,
              debug=False, use_dilated=False):
        """
        Build the VGG model using loaded weights
        Parameters
        ----------
        rgb: image batch tensor
            Image in RGB format, scaled to the interval [0, 255]
        train: bool
            Whether to build train or inference graph
        num_classes: int
            How many classes should be predicted (by fc8)
        random_init_fc8 : bool
            Whether to initialize fc8 layer randomly.
            Finetuning is required in this case.
        debug: bool
            Whether to print additional Debug Information.
        """
        # Convert RGB to BGR

        with tf.name_scope('Processing'):

            red, green, blue = tf.split(rgb, 3, 3)
            # assert red.get_shape().as_list()[1:] == [224, 224, 1]
            # assert green.get_shape().as_list()[1:] == [224, 224, 1]
            # assert blue.get_shape().as_list()[1:] == [224, 224, 1]
            bgr = tf.concat([
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], 3)

            if debug:
                bgr = tf.Print(bgr, [tf.shape(bgr)],
                               message='Shape of input image: ',
                               summarize=4, first_n=1)

        self.conv1_1 = self._conv_layer(bgr, "conv1_1")
        self.conv1_2 = self._conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self._max_pool(self.conv1_2, 'pool1', debug)

        self.conv2_1 = self._conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self._conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self._max_pool(self.conv2_2, 'pool2', debug)

        self.conv3_1 = self._conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self._conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self._conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self._max_pool(self.conv3_3, 'pool3', debug)

        self.conv4_1 = self._conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self._conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self._conv_layer(self.conv4_2, "conv4_3")

        if use_dilated:
            pad = [[0, 0], [0, 0]]
            self.pool4 = tf.nn.max_pool(self.conv4_3, ksize=[1, 2, 2, 1],
                                        strides=[1, 1, 1, 1],
                                        padding='SAME', name='pool4')
            self.pool4 = tf.space_to_batch(self.pool4,
                                           paddings=pad, block_size=2)
        else:
            self.pool4 = self._max_pool(self.conv4_3, 'pool4', debug)

        self.conv5_1 = self._conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self._conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self._conv_layer(self.conv5_2, "conv5_3")
        if use_dilated:
            pad = [[0, 0], [0, 0]]
            self.pool5 = tf.nn.max_pool(self.conv5_3, ksize=[1, 2, 2, 1],
                                        strides=[1, 1, 1, 1],
                                        padding='SAME', name='pool5')
            self.pool5 = tf.space_to_batch(self.pool5,
                                           paddings=pad, block_size=2)
        else:
            self.pool5 = self._max_pool(self.conv5_3, 'pool5', debug)

        self.fc6 = self._fc_layer(self.pool5, "fc6")

        if train:
            self.fc6 = tf.nn.dropout(self.fc6, 0.5)

        self.fc7 = self._fc_layer(self.fc6, "fc7")
        if train:
            self.fc7 = tf.nn.dropout(self.fc7, 0.5)

        if use_dilated:
            self.pool5 = tf.batch_to_space(self.pool5, crops=pad, block_size=2)
            self.pool5 = tf.batch_to_space(self.pool5, crops=pad, block_size=2)
            self.fc7 = tf.batch_to_space(self.fc7, crops=pad, block_size=2)
            self.fc7 = tf.batch_to_space(self.fc7, crops=pad, block_size=2)
            return

        if random_init_fc8:
            self.score_fr = self._score_layer(self.fc7, "score_fr",
                                              num_classes)
        else:
            self.score_fr = self._fc_layer(self.fc7, "score_fr",
                                           num_classes=num_classes,
                                           relu=False)

        self.pred = tf.argmax(self.score_fr, dimension=3)

        self.upscore2 = self._upscore_layer(self.score_fr,
                                            shape=tf.shape(self.pool4),
                                            num_classes=num_classes,
                                            debug=debug, name='upscore2',
                                            ksize=4, stride=2)
        self.score_pool4 = self._score_layer(self.pool4, "score_pool4",
                                             num_classes=num_classes)
        self.fuse_pool4 = tf.add(self.upscore2, self.score_pool4)

        self.upscore4 = self._upscore_layer(self.fuse_pool4,
                                            shape=tf.shape(self.pool3),
                                            num_classes=num_classes,
                                            debug=debug, name='upscore4',
                                            ksize=4, stride=2)
        self.score_pool3 = self._score_layer(self.pool3, "score_pool3",
                                             num_classes=num_classes)
        self.fuse_pool3 = tf.add(self.upscore4, self.score_pool3)

        self.upscore32 = self._upscore_layer(self.fuse_pool3,
                                             shape=tf.shape(bgr),
                                             num_classes=num_classes,
                                             debug=debug, name='upscore32',
                                             ksize=16, stride=8)

        self.pred_up = tf.argmax(self.upscore32, dimension=3)

    def _max_pool(self, bottom, name, debug):
        pool = tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                              padding='SAME', name=name)

        if debug:
            pool = tf.Print(pool, [tf.shape(pool)],
                            message='Shape of %s' % name,
                            summarize=4, first_n=1)
        return pool

    def _conv_layer(self, bottom, name):
        with tf.variable_scope(name) as scope:
            filt = self.get_conv_filter(name)
            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            # Add summary to Tensorboard
            _activation_summary(relu)
            return relu

    def _fc_layer(self, bottom, name, num_classes=None,
                  relu=True, debug=False):
        with tf.variable_scope(name) as scope:
            shape = bottom.get_shape().as_list()

            if name == 'fc6':
                filt = self.get_fc_weight_reshape(name, [7, 7, 512, 4096])
            elif name == 'score_fr':
                name = 'fc8'  # Name of score_fr layer in VGG Model
                filt = self.get_fc_weight_reshape(name, [1, 1, 4096, 1000],
                                                  num_classes=num_classes)
            else:
                filt = self.get_fc_weight_reshape(name, [1, 1, 4096, 4096])

            self._add_wd_and_summary(filt, self.wd, "fc_wlosses")

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
            conv_biases = self.get_bias(name, num_classes=num_classes)
            bias = tf.nn.bias_add(conv, conv_biases)

            if relu:
                bias = tf.nn.relu(bias)
            _activation_summary(bias)

            if debug:
                bias = tf.Print(bias, [tf.shape(bias)],
                                message='Shape of %s' % name,
                                summarize=4, first_n=1)
            return bias

    def _score_layer(self, bottom, name, num_classes):
        with tf.variable_scope(name) as scope:
            # get number of input channels
            in_features = bottom.get_shape()[3].value
            shape = [1, 1, in_features, num_classes]
            # He initialization scheme
            if name == "score_fr":
                num_input = in_features
                stddev = (2 / num_input)**0.5
            elif name == "score_pool4":
                stddev = 0.001
            elif name == "score_pool3":
                stddev = 0.0001
            # Apply convolution
            w_decay = self.wd

            weights = self._variable_with_weight_decay(shape, stddev, w_decay,
                                                       decoder=True)
            conv = tf.nn.conv2d(bottom, weights, [1, 1, 1, 1], padding='SAME')
            # Apply bias
            conv_biases = self._bias_variable([num_classes], constant=0.0)
            bias = tf.nn.bias_add(conv, conv_biases)

            _activation_summary(bias)

            return bias

    def _upscore_layer(self, bottom, shape,
                       num_classes, name, debug,
                       ksize=4, stride=2):
        strides = [1, stride, stride, 1]
        with tf.variable_scope(name):
            in_features = bottom.get_shape()[3].value

            if shape is None:
                # Compute shape out of Bottom
                in_shape = tf.shape(bottom)

                h = ((in_shape[1] - 1) * stride) + 1
                w = ((in_shape[2] - 1) * stride) + 1
                new_shape = [in_shape[0], h, w, num_classes]
            else:
                new_shape = [shape[0], shape[1], shape[2], num_classes]
            output_shape = tf.stack(new_shape)

            logging.debug("Layer: %s, Fan-in: %d" % (name, in_features))
            f_shape = [ksize, ksize, num_classes, in_features]

            # create
            num_input = ksize * ksize * in_features / stride
            stddev = (2 / num_input)**0.5

            weights = self.get_deconv_filter(f_shape)
            self._add_wd_and_summary(weights, self.wd, "fc_wlosses")
            deconv = tf.nn.conv2d_transpose(bottom, weights, output_shape,
                                            strides=strides, padding='SAME')

            if debug:
                deconv = tf.Print(deconv, [tf.shape(deconv)],
                                  message='Shape of %s' % name,
                                  summarize=4, first_n=1)

        _activation_summary(deconv)
        return deconv

    def get_deconv_filter(self, f_shape):
        width = f_shape[0]
        height = f_shape[1]
        f = ceil(width/2.0)
        c = (2 * f - 1 - f % 2) / (2.0 * f)
        bilinear = np.zeros([f_shape[0], f_shape[1]])
        for x in range(width):
            for y in range(height):
                value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
                bilinear[x, y] = value
        weights = np.zeros(f_shape)
        for i in range(f_shape[2]):
            weights[:, :, i, i] = bilinear

        init = tf.constant_initializer(value=weights,
                                       dtype=tf.float32)
        var = tf.get_variable(name="up_filter", initializer=init,
                              shape=weights.shape)
        return var

    def get_conv_filter(self, name):
        init = tf.constant_initializer(value=self.data_dict[name][0],
                                       dtype=tf.float32)
        shape = self.data_dict[name][0].shape
        print('Layer name: %s' % name)
        print('Layer shape: %s' % str(shape))
        var = tf.get_variable(name="filter", initializer=init, shape=shape)
        if not tf.get_variable_scope().reuse:
            weight_decay = tf.multiply(tf.nn.l2_loss(var), self.wd,
                                       name='weight_loss')
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 weight_decay)
        _variable_summaries(var)
        return var

    def get_bias(self, name, num_classes=None):
        bias_weights = self.data_dict[name][1]
        shape = self.data_dict[name][1].shape
        if name == 'fc8':
            bias_weights = self._bias_reshape(bias_weights, shape[0],
                                              num_classes)
            shape = [num_classes]
        init = tf.constant_initializer(value=bias_weights,
                                       dtype=tf.float32)
        var = tf.get_variable(name="biases", initializer=init, shape=shape)
        _variable_summaries(var)
        return var

    def get_fc_weight(self, name):
        init = tf.constant_initializer(value=self.data_dict[name][0],
                                       dtype=tf.float32)
        shape = self.data_dict[name][0].shape
        var = tf.get_variable(name="weights", initializer=init, shape=shape)
        if not tf.get_variable_scope().reuse:
            weight_decay = tf.multiply(tf.nn.l2_loss(var), self.wd,
                                       name='weight_loss')
            tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 weight_decay)
        _variable_summaries(var)
        return var

    def _bias_reshape(self, bweight, num_orig, num_new):
        """ Build bias weights for filter produces with `_summary_reshape`

        """
        n_averaged_elements = num_orig//num_new
        avg_bweight = np.zeros(num_new)
        for i in range(0, num_orig, n_averaged_elements):
            start_idx = i
            end_idx = start_idx + n_averaged_elements
            avg_idx = start_idx//n_averaged_elements
            if avg_idx == num_new:
                break
            avg_bweight[avg_idx] = np.mean(bweight[start_idx:end_idx])
        return avg_bweight

    def _summary_reshape(self, fweight, shape, num_new):
        """ Produce weights for a reduced fully-connected layer.

        FC8 of VGG produces 1000 classes. Most semantic segmentation
        task require much less classes. This reshapes the original weights
        to be used in a fully-convolutional layer which produces num_new
        classes. To archive this the average (mean) of n adjanced classes is
        taken.

        Consider reordering fweight, to perserve semantic meaning of the
        weights.

        Args:
          fweight: original weights
          shape: shape of the desired fully-convolutional layer
          num_new: number of new classes


        Returns:
          Filter weights for `num_new` classes.
        """
        num_orig = shape[3]
        shape[3] = num_new
        assert(num_new < num_orig)
        n_averaged_elements = num_orig//num_new
        avg_fweight = np.zeros(shape)
        for i in range(0, num_orig, n_averaged_elements):
            start_idx = i
            end_idx = start_idx + n_averaged_elements
            avg_idx = start_idx//n_averaged_elements
            if avg_idx == num_new:
                break
            avg_fweight[:, :, :, avg_idx] = np.mean(
                fweight[:, :, :, start_idx:end_idx], axis=3)
        return avg_fweight

    def _variable_with_weight_decay(self, shape, stddev, wd, decoder=False):
        """Helper to create an initialized Variable with weight decay.

        Note that the Variable is initialized with a truncated normal
        distribution.
        A weight decay is added only if one is specified.

        Args:
          name: name of the variable
          shape: list of ints
          stddev: standard deviation of a truncated Gaussian
          wd: add L2Loss weight decay multiplied by this float. If None, weight
              decay is not added for this Variable.

        Returns:
          Variable Tensor
        """

        initializer = tf.truncated_normal_initializer(stddev=stddev)
        var = tf.get_variable('weights', shape=shape,
                              initializer=initializer)

        collection_name = tf.GraphKeys.REGULARIZATION_LOSSES
        if wd and (not tf.get_variable_scope().reuse):
            weight_decay = tf.multiply(
                tf.nn.l2_loss(var), wd, name='weight_loss')
            tf.add_to_collection(collection_name, weight_decay)
        _variable_summaries(var)
        return var

    def _add_wd_and_summary(self, var, wd, collection_name=None):
        if collection_name is None:
            collection_name = tf.GraphKeys.REGULARIZATION_LOSSES
        if wd and (not tf.get_variable_scope().reuse):
            weight_decay = tf.multiply(
                tf.nn.l2_loss(var), wd, name='weight_loss')
            tf.add_to_collection(collection_name, weight_decay)
        _variable_summaries(var)
        return var

    def _bias_variable(self, shape, constant=0.0):
        initializer = tf.constant_initializer(constant)
        var = tf.get_variable(name='biases', shape=shape,
                              initializer=initializer)
        _variable_summaries(var)
        return var

    def get_fc_weight_reshape(self, name, shape, num_classes=None):
        print('Layer name: %s' % name)
        print('Layer shape: %s' % shape)
        weights = self.data_dict[name][0]
        weights = weights.reshape(shape)
        if num_classes is not None:
            weights = self._summary_reshape(weights, shape,
                                            num_new=num_classes)
        init = tf.constant_initializer(value=weights,
                                       dtype=tf.float32)
        var = tf.get_variable(name="weights", initializer=init, shape=shape)
        return var


def _activation_summary(x):
    """Helper to create summaries for activations.

    Creates a summary that provides a histogram of activations.
    Creates a summary that measure the sparsity of activations.

    Args:
      x: Tensor
    Returns:
      nothing
    """
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on tensorboard.
    tensor_name = x.op.name
    # tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
    tf.summary.histogram(tensor_name + '/activations', x)
    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_summaries(var):
    """Attach a lot of summaries to a Tensor."""
    if not tf.get_variable_scope().reuse:
        name = var.op.name
        logging.info("Creating Summary for: %s" % name)
        with tf.name_scope('summaries'):
            mean = tf.reduce_mean(var)
            tf.summary.scalar(name + '/mean', mean)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
            tf.summary.scalar(name + '/stddev', stddev)
            tf.summary.scalar(name + '/max', tf.reduce_max(var))
            tf.summary.scalar(name + '/min', tf.reduce_min(var))
            tf.summary.histogram(name, var)
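
A minimal usage sketch for the class above (hypothetical; assumes TensorFlow 1.x and a local vgg16.npy weight file):

import tensorflow as tf

# Build the inference graph; net.pred_up holds the per-pixel argmax prediction
# upsampled back to the input resolution.
images = tf.placeholder(tf.float32, shape=[None, None, None, 3])
net = FCN8VGG(vgg16_npy_path='vgg16.npy')
net.build(images, train=False, num_classes=20)
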
FCN using torchvision (pretrained VGG-16 backbone)
import torch.nn as nn
import torch.nn.functional as F
from abc import ABCMeta
import torchvision.models as models


def _maybe_pad(x, size):
    hpad = size[0] - x.shape[2]
    wpad = size[1] - x.shape[3]
    if hpad + wpad > 0:
        x = F.pad(x, (0, wpad, 0, hpad, 0, 0, 0, 0 ))
    return x


class VGGFCN(nn.Module, metaclass=ABCMeta):
    def __init__(self, in_channels, n_classes):
        super().__init__()
        assert in_channels == 3
        self.n_classes = n_classes
        self.vgg16 = models.vgg16(pretrained=True)
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=7, padding=3),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Conv2d(4096, 4096, kernel_size=1),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Conv2d(4096, n_classes, kernel_size=1),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        self.classifier[0].weight.data = (
            self.vgg16.classifier[0].weight.data.view(
                self.classifier[0].weight.size())
        )
        self.classifier[3].weight.data = (
            self.vgg16.classifier[3].weight.data.view(
                self.classifier[3].weight.size())
        )


class VGGFCN32(VGGFCN):
    def forward(self, x):
        input_height, input_width = x.shape[2], x.shape[3]
        x = self.vgg16.features(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=(input_height, input_width),
                          mode='bilinear', align_corners=True)
        return x


class VGGFCN16(VGGFCN):
    def __init__(self, in_channels, n_classes):
        super().__init__(in_channels, n_classes)
        self.score4 = nn.Conv2d(512, n_classes, kernel_size=1)
        self.upscale5 = nn.ConvTranspose2d(
            n_classes, n_classes, kernel_size=2, stride=2)

    def forward(self, x):
        input_height, input_width = x.shape[2], x.shape[3]
        pool4 = self.vgg16.features[:-7](x)
        pool5 = self.vgg16.features[-7:](pool4)
        pool5_upscaled = self.upscale5(self.classifier(pool5))
        pool4 = self.score4(pool4)
        x = pool4 + pool5_upscaled
        x = F.interpolate(x, size=(input_height, input_width),
                          mode='bilinear', align_corners=True)
        return x


class VGGFCN8(VGGFCN):
    def __init__(self, in_channels, n_classes):
        super().__init__(in_channels, n_classes)
        self.upscale4 = nn.ConvTranspose2d(
            n_classes, n_classes, kernel_size=2, stride=2)
        self.score4 = nn.Conv2d(
            512, n_classes, kernel_size=1, stride=1)
        self.score3 = nn.Conv2d(
            256, n_classes, kernel_size=1, stride=1)
        self.upscale5 = nn.ConvTranspose2d(
            n_classes, n_classes, kernel_size=2, stride=2)

    def forward(self, x):
        input_height, input_width = x.shape[2], x.shape[3]
        pool3 = self.vgg16.features[:-14](x)
        pool4 = self.vgg16.features[-14:-7](pool3)
        pool5 = self.vgg16.features[-7:](pool4)
        pool5_upscaled = self.upscale5(self.classifier(pool5))
        pool5_upscaled = _maybe_pad(pool5_upscaled, pool4.shape[2:])
        pool4_scores = self.score4(pool4)
        pool4_fused = pool4_scores + pool5_upscaled
        pool4_upscaled = self.upscale4(pool4_fused)
        pool4_upscaled = _maybe_pad(pool4_upscaled, pool3.shape[2:])
        x = self.score3(pool3) + pool4_upscaled
        x = F.interpolate(x, size=(input_height, input_width),
                          mode='bilinear', align_corners=True)
        return x
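
A quick smoke test for the torchvision-based classes above (assuming the pretrained VGG-16 weights can be downloaded):

import torch

model = VGGFCN8(in_channels=3, n_classes=21)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)   # torch.Size([1, 21, 224, 224]), per-pixel class scores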

Related Resources

U-Net

  • Pros: simple and easy to use; works well on small datasets, and performs especially well in medical image segmentation.
  • Cons: prone to overfitting; not well suited to large-scale datasets.


Motivation

U-Net is a classic and widely used segmentation model, favored for being simple, efficient, and easy to understand and build, and it is particularly well suited to training on small datasets. It was first proposed in 2015 in the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation" and remains a foundational model for medical image segmentation to this day.


  1. U-Net was originally proposed to address medical image segmentation;
  2. Its U-shaped architecture captures both contextual information and localization information;
  3. It won several first places in the 2015 ISBI cell tracking challenge; it was initially designed for cell-level segmentation tasks.

Network Architecture


U-Net is a classic encoder-decoder architecture, named for its overall shape resembling the capital letter "U". It is widely used in fields such as medical image segmentation. The design is very clean: the first half of the network extracts features (the encoder), and the second half upsamples (the decoder).

Encoder

The encoder occupies the left half of the network and consists of several downsampling blocks. Each block contains two 3×3 convolutional layers (with ReLU activations) followed by a 2×2 max-pooling layer, which extracts features and halves the spatial size. Through this structure the encoder progressively extracts deeper features while enlarging the receptive field.
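
As a sketch of this pattern (illustrative; padding=1 is used here to keep sizes aligned, whereas the original paper used unpadded convolutions):

import torch.nn as nn

# One U-Net encoder stage: two 3x3 conv + ReLU layers, then 2x2 max pooling,
# which halves the spatial size while the channel count (typically) doubles.
def encoder_stage(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )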

Decoder

The decoder occupies the right half of the network and consists of upsampling blocks. Each block contains a 2×2 transposed-convolution (up-convolution) layer that restores the feature map to the spatial size of the corresponding encoder level. The upsampled feature map is then concatenated channel-wise with the feature map from the matching encoder level, and two 3×3 convolutional layers (with ReLU activations) fuse the features further. This structure effectively combines deep and shallow features, capturing both global semantics and local detail.

Feature Fusion

Unlike FCN, which fuses features by element-wise addition of corresponding feature-map values, U-Net uses channel-wise concatenation. Concatenation produces a "thicker" feature map that preserves more detail, at the cost of higher memory consumption.
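
The contrast in one line each (a sketch with hypothetical shapes; skip and up stand for an encoder feature map and an upsampled decoder feature map of matching spatial size):

import torch

skip = torch.randn(1, 64, 56, 56)   # encoder (skip-connection) features
up = torch.randn(1, 64, 56, 56)     # upsampled decoder features

fcn_fused = skip + up                       # FCN: element-wise sum, still 64 channels
unet_fused = torch.cat([skip, up], dim=1)   # U-Net: channel concat, now 128 channels
print(fcn_fused.shape, unet_fused.shape)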

Advantages of U-Net

  1. Multi-scale feature fusion: by concatenating deep and shallow feature maps, U-Net makes full use of features at different levels. Shallow convolutions capture texture and detail, while deeper layers capture higher-level semantics. This fusion lets the model handle complex segmentation tasks better.
  2. Preserving edge features: some edge information is inevitably lost during downsampling, but through concatenation the decoder can recover that lost edge information from the encoder's shallow features, improving segmentation accuracy.

The way I see it, U-Net's benefits are: feature maps from deeper layers have a larger receptive field; shallow convolutions focus on texture features while deeper layers capture more essential, abstract features, so deep and shallow features are each meaningful in their own right. The other point is that the edges of the larger feature maps produced by deconvolution lack information: every downsampling step, while distilling features, inevitably loses some edge information, and what is lost cannot be recovered by upsampling alone. Concatenating features is how that edge information is brought back.

Below are some medical datasets and their extraction codes; interested readers can download them and give them a try.

| Dataset | Download link | Extraction code |
| --- | --- | --- |
| Cell dataset (dsb2018) | https://pan.baidu.com/share/init?surl=BaVrzYdrSP78CwYaRzZr1w | 5l54 |
| Liver dataset | https://pan.baidu.com/share/init?surl=FljGCVzu7HPYpwAKvSVN4Q | 5l88 |
| Cell dataset (isbi) | https://pan.baidu.com/share/init?surl=FkfnhU-RnYFZti62-f8AVA | 14rz |
| Lung dataset | https://pan.baidu.com/share/init?surl=sLFRmtG2TOTEgUKniJf7AA | qdwo |
| Corneal Nerve dataset | https://pan.baidu.com/share/init?surl=T3-kS_FgYI6DeXv3n1I7bA | ih02 |
| Eye Vessels (DRIVE dataset) | https://pan.baidu.com/share/init?surl=UkMLmdbM61N8ecgnKlAsPg | f1ek |
| Esophagus and Esophagus Cancer dataset (First Affiliated Hospital of Sun Yat-sen University) | https://pan.baidu.com/share/init?surl=0b5arIQjNpiggwdkgYNHXQ | hivm |

Why does U-Net perform well on medical image segmentation?

Most medical image segmentation work takes U-Net as the first baseline. The U-Net advantages covered in the previous section are certainly part of the answer; here we also consider the characteristics of medical images themselves.

Summarizing discussions from the community:

  1. Medical images have relatively simple semantics and fixed structure. Compared with autonomous driving and similar scenes, the semantic content is fairly uniform, so there is little irrelevant information to filter out. All features in a medical image matter, low-level and high-level alike, which is exactly where the U-shaped skip connections (feature concatenation) come into play.
  2. Medical imaging data is scarce and hard to acquire; a dataset may contain only a few hundred samples, or even fewer than 100. Large networks such as DeepLabv3+ therefore overfit easily. The advantage of large networks is stronger representational power, but simple, small medical datasets do not contain that much to represent; indeed, on small datasets people have found that SOTA segmentation models show no real advantage over a lightweight U-Net.
  3. Medical images are often multi-modal. In the ISLES stroke-lesion challenge, for example, the organizers provide several modalities such as CBF, MTT, and CBV (it is fine if this is unfamiliar). Medical imaging tasks therefore often require custom network designs to extract features from the different modalities, and a lightweight, structurally simple U-Net leaves more room for such modifications.

Q: Besides model complexity, what other factors does overfitting depend on?

Model Code

Source code
import torch.nn as nn
import torch
from functools import partial
import torch.nn.functional as F
from torchvision import models

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)

class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()

        self.conv1 = DoubleConv(in_ch, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(256, 512)
        self.up6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6 = DoubleConv(512, 256)
        self.up7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7 = DoubleConv(256, 128)
        self.up8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8 = DoubleConv(128, 64)
        self.up9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9 = DoubleConv(64, 32)
        self.conv10 = nn.Conv2d(32, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)   # channel-wise skip concatenation
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out

nonlinearity = partial(F.relu, inplace=True)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, n_filters):
        super(DecoderBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)
        self.relu1 = nonlinearity

        self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3,
                                          stride=2, padding=1, output_padding=1)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)
        self.relu2 = nonlinearity

        self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nonlinearity

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x

class resnet34_unet(nn.Module):
    def __init__(self, num_classes=1, num_channels=3, pretrained=True):
        super(resnet34_unet, self).__init__()

        filters = [64, 128, 256, 512]
        resnet = models.resnet34(pretrained=pretrained)
        self.firstconv = resnet.conv1
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        self.decoder4 = DecoderBlock(512, filters[2])
        self.decoder3 = DecoderBlock(filters[2], filters[1])
        self.decoder2 = DecoderBlock(filters[1], filters[0])
        self.decoder1 = DecoderBlock(filters[0], filters[0])

        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
        self.finalrelu1 = nonlinearity
        self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.finalrelu2 = nonlinearity
        self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)

    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder (skip connections fused by element-wise addition here)
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        out = self.finaldeconv1(d1)
        out = self.finalrelu1(out)
        out = self.finalconv2(out)
        out = self.finalrelu2(out)
        out = self.finalconv3(out)

        return nn.Sigmoid()(out)
Model implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3,strides=1,padding=1):
        super(double_conv2d_bn,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,
                               kernel_size=kernel_size,
                              stride = strides,padding=padding,bias=True)
        self.conv2 = nn.Conv2d(out_channels,out_channels,
                              kernel_size = kernel_size,
                              stride = strides,padding=padding,bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
  
    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out
  
class deconv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=2,strides=2):
        super(deconv2d_bn,self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels,out_channels,
                                        kernel_size = kernel_size,
                                       stride = strides,bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
      
    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out
  
class Unet(nn.Module):
    def __init__(self):
        super(Unet,self).__init__()
        self.layer1_conv = double_conv2d_bn(1,8)
        self.layer2_conv = double_conv2d_bn(8,16)
        self.layer3_conv = double_conv2d_bn(16,32)
        self.layer4_conv = double_conv2d_bn(32,64)
        self.layer5_conv = double_conv2d_bn(64,128)
        self.layer6_conv = double_conv2d_bn(128,64)
        self.layer7_conv = double_conv2d_bn(64,32)
        self.layer8_conv = double_conv2d_bn(32,16)
        self.layer9_conv = double_conv2d_bn(16,8)
        self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
                                     stride=1,padding=1,bias=True)
      
        self.deconv1 = deconv2d_bn(128,64)
        self.deconv2 = deconv2d_bn(64,32)
        self.deconv3 = deconv2d_bn(32,16)
        self.deconv4 = deconv2d_bn(16,8)
      
        self.sigmoid = nn.Sigmoid()
      
    def forward(self,x):
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1,2)
      
        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2,2)
      
        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3,2)
      
        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4,2)
      
        conv5 = self.layer5_conv(pool4)
      
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1,conv4],dim=1)
        conv6 = self.layer6_conv(concat1)
      
        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2,conv3],dim=1)
        conv7 = self.layer7_conv(concat2)
      
        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3,conv2],dim=1)
        conv8 = self.layer8_conv(concat3)
      
        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4,conv1],dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp
  

model = Unet()
inp = torch.rand(10,1,224,224)
outp = model(inp)
print(outp.shape)
# ==> torch.Size([10, 1, 224, 224])

Related Resources

Projects

https://github.com/Andy-zhujunwen/UNET-ZOO?tab=readme-ov-file
https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
https://www.codewithgpu.com/i/bubbliiiing/unet-pytorch/UNet-PyTorch
https://huggingface.co/spaces/h2chen/demo_unet

Blogs

UNet Explained (with figures and code) - CSDN Blog
Essential knowledge for image segmentation | UNet explained: theory + code - 忽逢桃林 - cnblogs