nnU-Net 模型学习笔记
0. 背景
这篇主要是对论文 nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation 的学习记录。
一开始看到 nnU-Net 的时候,其实很容易把它理解成“一个更厉害的 U-Net 变体”。但是读完论文之后发现,它真正重要的地方好像并不是提出了什么特别炫的新网络结构,而是把医学图像分割里一堆容易被忽略的工程细节系统化了。
论文里也直接把 nnU-Net 解释成 no-new-Net,这个名字还挺有意思的。它大概想表达的是:
不一定非要发明一个新的网络结构,先把 U-Net 该做好的事情全部做好。
这点对我还挺有提醒意义的。因为之前做 DICOM 预处理、期相配准、nnU-Net 输入格式的时候,我也经常会下意识觉得“模型才是核心”,但实际上数据怎么裁剪、怎么重采样、spacing 怎么统一、标签怎么插值、patch 怎么取,这些东西如果没处理好,后面的模型可能只是看起来很努力。
1. nnU-Net 想解决什么问题
医学图像分割里面,U-Net 本身已经是一个很经典的结构了。
但是问题在于,不同数据集之间差异非常大:
有的是 CT,有的是 MRI;
有的是 2D 图像,有的是 3D 体数据;
有的 spacing 比较均匀,有的 z 轴很厚;
有的图像范围很大,比如肝脏、肺;
有的目标很小,比如海马体、胰腺病灶;
有的数据集病例很多,有的数据集病例很少。
所以同样是 U-Net,直接拿来用并不一定能跑出好结果。
如果每来一个新任务,都靠人工去调网络深度、patch size、batch size、归一化方式、数据增强方式、推理策略,那整个流程就会非常依赖经验,也很难保证可复现。
nnU-Net 的思路就是:
给定一个新的医学图像分割任务,尽量自动分析数据集特征,然后自动生成一套适合这个任务的 U-Net 训练和推理方案。
所以它不是单纯的一个模型,而更像是一个完整的自动化分割框架。
2. 核心理解:它不追求“新”,而追求“配好”
论文里有一个观点我觉得挺关键:
很多所谓的新结构,在没有把 baseline 充分调好的情况下,看起来确实能提升性能。但如果基础 U-Net 已经被认真配置过,很多花哨结构带来的提升可能就没有那么明显了。
所以 nnU-Net 没有采用 residual connection、dense connection、attention mechanism 这些当时常见的改造,而是只在原始 U-Net 上做了很小的改动:
使用 leaky ReLU;
使用 instance normalization;
encoder 和 decoder 仍然保持比较朴素的 U-Net 结构;
重点放在自动化配置网络、预处理、训练、推理和后处理。
这个地方让我感觉,nnU-Net 的强并不是“网络结构很神秘”,而是它把一个医学图像分割任务中真正影响性能的变量都认真管起来了。
如果用一句话概括:
nnU-Net 不是换了一辆更夸张的车,而是把路线、轮胎、油量、驾驶方式和终点检查都安排好了。
3. nnU-Net 中的三类 U-Net
论文里主要准备了三种基础模型:
2D U-Net;
3D U-Net;
U-Net Cascade。
它们不是互相替代的关系,而是针对不同数据特点各有适用场景。
3.1 2D U-Net
2D U-Net 每次处理的是二维切片。
直觉上看,用 2D 网络处理 3D 医学图像好像有点不充分,因为它看不到 z 轴方向的上下文信息。
但是论文里提到,对于一些各向异性很强的数据,2D U-Net 反而可能是合理的选择。比如 z 轴 spacing 很大、层厚比较厚的时候,相邻切片之间的信息本来就没有那么连续,强行使用 3D 卷积不一定更好。
这点和我现在处理多期相 CT 也有关系。不能一看到医学图像是三维的,就默认 3D 网络一定更优。还要看:
z 轴 spacing;
切片厚度;
上下层解剖结构是否连续;
目标区域是否需要强三维上下文。
所以 2D U-Net 在 nnU-Net 里面不是落后选项,而是一个针对特定数据几何特征的候选方案。
3.2 3D U-Net
3D U-Net 更符合三维医学图像的直觉。
它可以同时利用 x、y、z 三个方向的信息,对于器官、肿瘤这类三维结构来说,理论上会更适合。
但是问题也很现实:
显存不够。
理想情况下,当然希望把整个病人的 3D 图像都送进网络。可是实际图像可能非常大,比如肝脏 CT 的体素数量很高,不可能直接整幅输入,所以只能基于 patch 训练。
patch-based training 又会带来另一个问题:
patch 太小,模型看不到足够大的上下文;
patch 太大,batch size 又会太小,甚至显存爆掉;
对大器官来说,局部 patch 可能很难判断自己到底处在什么解剖位置。
所以 3D U-Net 虽然看起来是医学图像分割的自然选择,但它仍然要在 patch size、batch size 和显存之间不断权衡。
3.3 U-Net Cascade
U-Net Cascade 是为了解决大图像中 3D U-Net 视野不够的问题。
它分成两个阶段:
低分辨率 3D U-Net
↓
先在降采样图像上得到一个粗分割
↓
把粗分割上采样回原始分辨率
↓
作为额外输入通道交给第二个 3D U-Net
↓
在全分辨率 patch 上做精细分割
简单理解就是:
第一阶段先让模型在低分辨率下看到更大的全局范围,知道大概哪里是目标;第二阶段再回到高分辨率,利用粗分割结果辅助局部细化。
这个设计对大器官或者大范围图像比较有意义。比如 Liver、Lung、Pancreas 这类任务,单纯全分辨率 3D patch 可能视野太小,所以需要先用低分辨率结果提供一个全局提示。
这也提醒我,分割不只是“边界画得准不准”,还有一个很重要的问题是:
模型到底有没有看到足够多的上下文。
4. 自动配置网络结构
nnU-Net 会根据数据集的中位数图像大小和 spacing 自动决定网络配置。
论文里提到的几个关键变量包括:
input patch size;
batch size;
每个轴上的 pooling 次数;
feature map 数量;
是否需要 cascade。
这里最重要的是,它不是手动给每个任务写一套参数,而是根据数据几何特征自动生成。
比如 3D U-Net 默认从 128 × 128 × 128 的 patch 和 batch size 为 2 的配置出发,然后根据数据集的 median shape 调整 patch 的长宽高比例。如果数据本身比这个还小,就直接用接近整幅图像的 patch,并相应增大 batch size。
每个轴上的 pooling 次数也不是固定的,而是根据 feature map 的尺寸来决定。大概就是沿着某个轴不断下采样,直到这个轴的 feature map 尺寸不能再合理继续缩小。
这个地方其实很实用。因为医学图像的尺寸不是自然图像那种比较统一的 224 × 224,而是经常出现:
512 × 512 × 几十层
512 × 512 × 几百层
几十 × 几十 × 几十
如果网络结构完全固定,就很难同时适应这些任务。
5. 预处理才是很大一部分核心
论文里的 preprocessing 包括三件比较重要的事:
cropping;
resampling;
normalization。
这部分我觉得反而是 nnU-Net 最值得学的地方之一。
5.1 Cropping
nnU-Net 会把图像裁剪到非零区域。
对于很多 CT 图像来说,这一步可能影响不大,因为背景和身体区域本来就比较完整。但对于 skull-stripped brain MRI 这种图像,非零区域外可能有大量空背景,裁剪后可以减少计算量。
这和我之前做共同有效区域裁剪的思路有点像:不是为了改变图像内容,而是为了让模型少处理没有意义的空间。
5.2 Resampling
论文里说得很直接:
CNN 本身并不知道 voxel spacing。
这句话很重要。
模型看到的是数组,但医学图像里的一个 voxel 代表多大的真实物理空间,是由 spacing 决定的。如果同一个器官在一个病例里 spacing 是 1 × 1 × 1,另一个病例里是 1 × 1 × 5,那模型直接看数组时其实并不知道它们在真实空间中的比例差异。
所以 nnU-Net 会把病例重采样到该数据集的 median voxel spacing。
其中:
图像数据使用三阶样条插值;
segmentation mask 使用最近邻插值。
标签必须用最近邻插值,这点非常关键。因为标签是类别,不是连续灰度值。如果用线性插值,标签值可能变成小数,语义就被破坏了。
这也和我目前做 DICOM 转 NIfTI、多通道输入非常相关。nnU-Net 要求图像通道空间对齐,本质上就是 size、spacing、origin、direction 这些信息要一致,否则不同通道之间的体素就不是同一个物理位置。
5.3 Normalization
nnU-Net 对 CT 和非 CT 使用不同归一化方式。
CT 的 HU 值有相对明确的物理意义,所以论文中会收集训练集 segmentation mask 内的强度值,然后:
按
0.5%和99.5%分位数进行裁剪;根据收集到的均值和标准差做 z-score normalization。
对于 MRI 或其他模态,因为强度没有统一的绝对尺度,所以更倾向于对每个 patient 单独做 z-score normalization。
这点也挺符合直觉。CT 的数值有比较稳定的含义,MRI 的数值则更依赖设备和扫描协议,所以不能用完全一样的归一化策略。
6. 训练策略
nnU-Net 的训练也没有走特别奇怪的路线。
论文里使用的是:
dice loss + cross entropy loss;
Adam optimizer;
初始学习率
3e-4;每个 epoch 定义为 250 个 training batches;
五折交叉验证;
所有模型从头训练。
其中 dice loss 适合处理医学图像分割中的类别不平衡问题,因为很多目标区域相对于整幅图像都很小。cross entropy 则提供逐体素分类监督。
数据增强方面,nnU-Net 使用了:
random rotations;
random scaling;
random elastic deformations;
gamma correction;
mirroring。
另外还有一个我觉得很实际的策略:
一个 batch 中超过三分之一的 samples 要包含至少一个随机选择的 foreground class。
因为如果完全随机采 patch,小目标任务里很容易抽到一堆背景 patch,模型训练时就会长期看到“哪里都没有目标”,这显然不太合理。
所以 patch sampling 其实也不是一个小细节,它直接决定模型能不能有效看到前景。
7. 推理策略
由于训练是 patch-based,推理也采用 patch-based。
但是 patch 边缘位置的预测通常不如中心稳定,所以 nnU-Net 在合并多个 patch 预测时,会给 patch 中心区域更高的权重。
推理时还有两个增强稳定性的做法:
相邻 patch 之间有
patch size / 2的重叠;使用 mirroring 做 test time augmentation。
也就是说,同一个 voxel 最后可能来自多个 patch、多个翻转版本的预测结果,再进行聚合。
论文里还使用五折交叉验证得到的五个模型做 ensemble,进一步提高鲁棒性。
这个地方能看出来,nnU-Net 对“稳定”这件事很执着。不是只让模型预测一次就结束,而是通过重叠 patch、中心加权、TTA、ensemble 尽量降低偶然误差。
8. 后处理
nnU-Net 的 postprocessing 也比较朴素。
它会在训练集标签上做 connected component analysis。如果某个类别在训练集里总是单个连通区域,那么推理时就会把该类别预测结果中除最大连通区域以外的部分去掉。
这个逻辑不是拍脑袋加的,而是从训练集统计出来的。
比如某个器官理论上应该只出现一个主体区域,那么模型预测出很多零散小块时,保留最大连通区域可能更合理。
但这个策略也不能乱用。如果一个类别本来就可能出现多个病灶,那么强行只保留最大连通区域就会删掉真实目标。所以 nnU-Net 的做法是先检查训练集分布,再决定是否应用。
9. 实验结果说明了什么
论文在 Medical Segmentation Decathlon 上做实验。这个挑战比较重要的一点是,它要求方法能够跨多个任务泛化,而不是只在一个数据集上调到很好。
nnU-Net 在 7 个 phase 1 任务上做五折交叉验证,然后提交 held-out test set。论文结果显示,它在多个任务和类别上取得了当时 leaderboard 上很靠前的 Dice 分数。
我觉得这里更重要的不是某一个具体 Dice 数值,而是它验证了一个思路:
一个认真自动配置的基础 U-Net pipeline,可以在很多不同医学图像任务上表现得非常强。
这其实比“我在某一个数据集上加了一个模块涨了几点”更有说服力。
10. 对我当前任务的启发
读完这篇论文之后,对我现在的任务主要有几个启发。
10.1 不要把前处理当成附属品
之前做 DICOM、registration、NIfTI 转换的时候,总觉得这些是在给模型“准备数据”。
现在感觉更准确的说法应该是:
前处理本身就是模型性能的一部分。
spacing 没统一、多通道没对齐、标签插值方式不对、裁剪区域不一致,这些问题不是训练时多跑几个 epoch 就能补回来的。
尤其是 nnU-Net 这种框架,它之所以能自动适配任务,很大程度上就是因为它把这些规则固定下来,并且让它们可复现。
10.2 多通道输入必须先保证空间一致
nnU-Net 支持多模态输入时,本质上会把不同模态或不同序列作为不同 channel。
但是 channel 之间必须对应同一个空间位置。
所以我之前一直纠结的 fixed image、moving image、resampling、共同区域裁剪,其实都是为了满足这个前提:
case_001_0000.nii.gz
case_001_0001.nii.gz
case_001_0002.nii.gz
这些文件不能只是同一个病例的不同序列,还必须在同一个 voxel index 上对应同一个物理位置。
否则模型看到的多通道信息就是错位的。
10.3 不要盲目追求复杂结构
这篇论文给我的另一个提醒是,不要太早陷入“我要不要换更复杂模型”的想法。
很多时候,模型没跑好,不一定是因为网络结构不够先进,而可能是:
数据没有整理干净;
spacing 不一致;
patch size 不合理;
foreground sampling 不充分;
normalization 不适合当前模态;
推理时没有处理 patch 边缘问题;
后处理规则和任务不匹配。
所以对当前阶段来说,与其急着改网络,不如先把 nnU-Net 需要的数据格式、预处理流程、训练配置和结果检查做好。
11. 总结
nnU-Net 给我的感觉是,它不是在回答“怎么设计一个更新的 U-Net”,而是在回答:
怎么让 U-Net 在一个新的医学图像分割任务上尽可能稳定地工作。
它的核心不是某个单独模块,而是一整套自动化 pipeline:
读取数据集特征
↓
自动决定 2D / 3D / cascade 配置
↓
裁剪、重采样、归一化
↓
patch-based training
↓
dice + CE loss
↓
数据增强和前景采样
↓
patch-based inference
↓
TTA、ensemble、后处理
所以 nnU-Net 最值得学习的地方,可能不是“它用了什么神奇网络”,而是它把医学图像分割里一整套容易出错的流程都规范化了。
对于我现在的任务来说,最重要的结论大概就是:
先把数据空间、预处理和输入格式做对,再谈模型训练。
这句话听起来很朴素,但可能确实是目前最应该记住的事情。