MONAI 专为医学AI开发的开源框架(From Nvidia)

MONAI简介:

https://blogs.nvidia.com/blog/2020/04/21/monai-open-source-framework-ai-healthcare/

https://monai.io/

https://medium.com/pytorch/monai-public-alpha-is-now-available-54b79f5532aa

https://github.com/Project-MONAI/MONAI

就先简单按源码这个结构来吧:

1 apps:

dataset.py:这个文件中定义了两个常见的数据集的Dataset类:MedNISTDataset和DecathlonDataset,继承父类Randomizable, CacheDataset,里边分别定义了随机数,和cachedataset的形式。

utils.py:定义了一些通过url下载文件,验证MD5,解压数据集的code

2 config:

deviceconfig.py:获取系统的一些版本配置信息

type_definitions.py:这个我没看太明白,貌似是给两个贯穿MONAI的概念定义了名字和类型,为了统一的使用。

定义了KeysCollection和IndexSelection

3 data:???

1 csv_saver.py
# 保存dict的结果到csv文件,预测结果之类的,可以save单个数据,也可以savebatch,finalize写入文档。有一点存疑,在overwrite为False并且存在文件时,他会先读取已有文件中的信息,保存到要写入disk的dict中,但是我担心他遇到key值相同的情况怎么办???

2 dataloader.py
# 普通的dataloader,继承自pytorch原生的,没发现什么特殊的地方目前

3 dataset.py
# 有几个自定义类以及一些应用函数
# class Dataset,继承自pytorch原生Dataset,看这一层没有什么特殊的,具体的要看transform和读取了
# class PersistentDataset,继承自自定义的Dataset,顾名思义他会将非随机性的transform(这里是读取标准化之类的操作)在第一次时做好,保存到硬盘里,后边就直接读取了,不再每次再做重复操作了。
# 注:这里用到了pathlib模块(相比os.path,nested -> chained,且os太过臃肿,还有一些其他的小区别)
# class CacheDataset,继承自自定义的Dataset,也是先处理一下非random的操作,但是这个是存在内存中的,内存大的服务器就很香了
# class ZipDataset,继承自自定义的Dataset,貌似是处理同时多个dataset的情况,这个我还没具体用过
# class ArrayDataset,继承自Randomizable和pytorch的dataset,多个数据集的话里边也用到了ZipDataset。还有些疑问我没看到他做增广啊???这个和zipdataset到时还需细看一下???

4 decathalon_datalist.py
# 感觉就是几个和十项全能数据及有关的代码

5 grid_dataset.py???
class GridPatchDataset(IterableDataset): 
#我这个阅读代码能力有些减弱了,没看太明白,但是大概意思是,把array分成块,生成出来,好像块儿与块儿之间没有overlap???
IterableDataset就没看太懂,itertool相关的东西还需要巩固。还有多线程的东西掌握也不好

6 nifti_reader.py
class NiftiDataset(Dataset, Randomizable):
# 常规载入Nifiti格式数据的一个Dataset类,里边比较复杂的是用到了monai.transforms 里边的 LoadNifti

7 nifti_saver.py
class NiftiSaver:
# 一个保存数据为Nifti格式的类。支持的输入数据格式可以为单个数据,也可以是一个batch的数据。一般来说保存的都是分割的预测结果。保存时用到了write_nifti函数。

8 nifti_writer.py
def write_nifti(...):
# 考虑了几种情况,是不是需要affine;channel调整的问题;

9 png_saver.py
10 png_writer.py
# 类似于Nifti格式,使用的PIL包来处理

11 synthetic.py
# 。。。我看是生成噪声图 和 一堆重叠圆的test图。。。没发现什么实际意义

12 utils.py ???
# 很多小工具,还没细看???

 

11 synthetic 效果

4 engines: 总的来说是宏观上和训练测试有关的代码.

1 trainer.py
a. class Trainer(Workflow):
# 所有trainer的基类,继承于Workflow
# def run(self)基于Ignite Engine训练

b. class SupervisedTrainer(Trainer):
# 标准???的监督训练方式

2 workflow.py
class Workflow(IgniteEngine):
# 一个  训练相关的 类

3 multi_gpu_supervised_trainer.py
# 多gpu,继承ignite的,其实就是多了个Dataparallel...

4 evaluater.py
class Evaluator(Workflow):
# 类似trainer.py

5 handlers

1 checkpoint_loader.py
class CheckpointLoader:
# CheckpointLoader acts as an Ignite handler to load checkpoint data from file.
# It can load variables for network, optimizer, lr_scheduler, etc.

2 checkpoint_saver.py
# 就一些保存的不同情况设置

3 classification_saver.py
# 使用CSVSaver将分类结果保存至csv文件,以及一些保存设置,是否覆盖已有balabala ..

4 lr_schedule_handler.py???
# lr相关的handler,handler的概念仍需掌握??? 都有个attach Ignite 的 event里边???

5 mean_dice.py
class MeanDice(Metric): # ignite.metrics里的类Metric
# 就是计算dice,一些操作加了ignite中的@reinit__is_reduced(修饰器)

6 metric_logger.py
class MetricLogger:
# 记录loss和metric,也有attach,接收engine,看来和其他的handler差不多,就是名字不一样

7 roc_auc.py
class ROCAUC(Metric):
# 计算roc auc的类,调用rocauc.py文件里的compute_roc_auc函数也是

8 segmentation_saver.py
class SegmentationSaver:
# 分割结果保存的代码,里边调用了前边data文件夹中的NiftiSaver和PNGSaver类;

9 stats_handler.py
class StatsHandler(object):
# 负责一些log打印的逻辑,如果没有特别指定epoch_print_logger或者iteration_print_logger,会使用默认的_default_epoch_print或者_default_iteration_print标准形式来打印。

10 tensorboard_handlers.py
a. class TensorBoardStatsHandler(object):
# 类似于stats_handler.py,也是tensorboard相关记录的逻辑,如果没有特殊指定,会自动使用default的记录形式

b. class TensorBoardImageHandler(object):
# 和上述不同的是,a是记录数值的,这个是记录Image可视化的
# 2D会显示batch中的第一个数据,3D会以gif形式显示后三个维度

11 utils.py
# 分别根据metric和loss指标 提前终止实验的func

12 validation_handler.py
class ValidationHandler:
# 将validator attach 到 trainer上,每N个epochs或者N个iterations进行一次验证
# 注意这里只是个validator的训练配置函数,非Evaluator ignite engine的逻辑实现

6 inferers

1 inferer.py
a. class Inferer(ABC):
# 模型inference的基类,

b. class SimpleInferer(Inferer):
# 最简单的inference,直接测

c. class SlidingWindowInferer(Inferer):
# emm 一个sliding window的类,包含一些参数设置,但是不涉及实现,调用的utils.py中的sliding_window_inference

2 utils.py
# emm sliding_window_inference的实现,就是那种实现
# 目前函数只支持batch=1的sliding winfow batch_size
# 另:MONAI的数据格式支持问题:除了是Nifti,数据的顺序为BCHWD而不是常用的BCDHW

# todo: 这里有个问题需要注意啊,MONAI中貌似对3D数据格式的设置都是HWD而不是常规的DHW.(仍需确认???)

7 losses

1 dice.py
a. class DiceLoss(_Loss):
# 数据格式为BCHWD
# 设置了计算loss时候可以去除背景的计算,因为在针对一些小目标时,将背景引入进来会淹没前景的训练,所以去除有助于收敛
# [个人疑问:去除背景会不会导致假阳性的出现,这时候可能就只能依靠entropy类的loss来学习背景了;另外,如果类间直接平均也还好吧,如果大家有其他想法欢迎指教~]
# 普通的dice计算功能

b. class MaskedDiceLoss(DiceLoss):
# 和diceloss计算一样,只不过加了个binary的mask来限制计算的区域

c. class GeneralizedDiceLoss(_Loss):
# 出自这篇文章:Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations. DLMIA 2017.
# 相比diceloss加了权重,gt前景范围的倒数,来平衡不同尺度前景的diceloss 的权重.


2 focal_loss.py
class FocalLoss(_WeightedLoss):
# CrossEctropyLoss也是_WeightedLoss的子类
# focalloss 的实现,可以详细看一下过程,毕竟官方segmentation focal loss实现,还是要优雅的

3 tversky.py
class TverskyLoss(_Loss):
# 1-tp/(tp+fp+fn) 其中fp和fn有权重,算是前景背景不平衡时均衡的一种loss吧

8 metrics

1 meandice.py
class DiceMetric:
# It can support both multi-classes and multi-labels tasks.
# 嗯

2 rocauc.py
# compute_roc_auc的实现,会被调用

9 networks

### blocks ###
1 aspp.py
class SimpleASPP(nn.Module):

......未完待续

有他自己的写法,需要详细看一下

10 transforms

### croppad ###
1 array.py
a. class SpatialPad(Transform):
b. class BorderPad(Transform):
c. class DivisiblePad(Transform):
d. class SpatialCrop(Transform):
e. class CenterSpatialCrop(Transform):
f. class RandSpatialCrop(Randomizable, Transform):
g. class RandSpatialCropSamples(Randomizable, Transform):
h. class CropForeground(Transform):
i. class RandCropByPosNegLabel(Randomizable, Transform):

2 dictionary.py
a. class SpatialPadd(MapTransform):
b. class BorderPadd(MapTransform):
c. class DivisiblePadd(MapTransform):
d. class SpatialCropd(MapTransform):
e. class CenterSpatialCropd(MapTransform):
f. class RandSpatialCropd(Randomizable, MapTransform):
g. class RandSpatialCropSamplesd(Randomizable, MapTransform):
h. class CropForegroundd(MapTransform):
i. class RandCropByPosNegLabeld(Randomizable, MapTransform):
### intensity ###
1 array.py
a. class RandGaussianNoise(Randomizable, Transform):
b. class ShiftIntensity(Transform):
c. class RandShiftIntensity(Randomizable, Transform):
d. class ScaleIntensity(Transform):
e. class RandScaleIntensity(Randomizable, Transform):
f. class NormalizeIntensity(Transform):
g. class ThresholdIntensity(Transform):
h. class ScaleIntensityRange(Transform):
i. class AdjustContrast(Transform):
j. class RandAdjustContrast(Randomizable, Transform):
k. class ScaleIntensityRangePercentiles(Transform):
l. class MaskIntensity(Transform):

2 dictionary.py
# 功能同上述array.py中的类
### io ###
1 array.py
a. class LoadNifti(Transform):
b. class LoadPNG(Transform):
c. class LoadNumpy(Transform):

2 dictionary.py
# 功能类似上
# 区别在于返回的是dict
### post ###
1 array.py
a. class SplitChannel(Transform):
# return list

b. class Activations(Transform):
# 给模型输出加上激活

c. class AsDiscrete(Transform):

d. class KeepLargestConnectedComponent(Transform):

e. class LabelToContour(Transform):

2 dictionary.py
# 功能类似上述,返回d
### spatial ###
1 array.py
class Spacing(Transform):
class Orientation(Transform):
class Flip(Transform):
class Resize(Transform):
class Rotate(Transform): # 和orientation区别什么???确认一下???
class Zoom(Transform):
class Rotate90(Transform):
class RandRotate90(Randomizable, Transform):
class RandRotate(Randomizable, Transform):
class RandFlip(Randomizable, Transform):
class RandZoom(Randomizable, Transform):

class AffineGrid(Transform):
class RandAffineGrid(Randomizable, Transform):
class RandDeformGrid(Randomizable, Transform):
class Resample(Transform):
class Affine(Transform):
class RandAffine(Randomizable, Transform):
class Rand2DElastic(Randomizable, Transform):
class Rand3DElastic(Randomizable, Transform):

2 dictionary.py
# 全不全不知道,类似上述.
### utils.py ###
# 一些 小工具
# 然后
adaptors.py????????????????????????????????????
compose.py
utils.py

11 utils

1 aliases.py
def alias(*names):???????????????????????????
def resolve_name(name):?????????

2 decorators.py
a. timing # 用于记录func的时间
b. class RestartGenerator:
c. class MethodReplacer(object):????????????????????????????

3 enums.py
# 都是Enum的subclass.

4 misc.py
# miscellaneous顾名思义杂项

5 module.py

12 visualize

可视化的一些,3D数据顺序HWD

~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~

Ignite相关简介:

其中MONAI很大一部分代码基于ignite格式编写, ignite框架的核心基础为class Engine, 如此便可以实现简单的训练和validation策略了;

为了Engine的灵活性. event system被引入用于促进每一步的交互性:

  • engine is started/completed
  • epoch is started/completed
  • batch iteration is started/completed

所有的Event列表可见:Events

用户可以自定义code作为event handler执行, handler的定义形式不做要求.

########

让我们通过一个例子来了解当run运行时, 发生了什么细节:

fire_event(Events.STARTED)
while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)

        output = process_function(batch)

        fire_event(Events.ITERATION_COMPLETED)
    fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)

当一个事件被触发(event is fired), 这个event对应的handlers都会被执行. 添加handler很简单,  add_event_handler() or on() decorator都可以实现.

#######

ignite提供了一系列内部handlers,可以参考ignite.handlersignite.contrib.handlers

########

state:

A state is introduced in Engine to store the output of the process_function, current epoch, iteration and other helpful information. Each Engine contains a State, which includes the following:

  • engine.state.seed: Seed to set at each data “epoch”.
  • engine.state.epoch: Number of epochs the engine has completed. Initializated as 0 and the first epoch is 1.
  • engine.state.iteration: Number of iterations the engine has completed. Initialized as 0 and the first iteration is 1.
  • engine.state.max_epochs: Number of epochs to run for. Initializated as 1.
  • engine.state.output: The output of the process_function defined for the Engine. See below.
  • etc

Other attributes can be found in the docs of State.

########

~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~·~

学习到的散装知识点:

1 tuple是可以比较大小的,code中版本比较是这样做的

# 提取版本前两位到tuple,判断版本是否满足需求
(1,4)>(1,3)
return True # 1.4版本大于1.3版本

2 package 和Module的范围,以及namespace module

从本质上讲,包命名空间(namespace package)是一种特殊的封装设计,为合并不同的目录的代码到一个共同的命名空间

关键是确保顶级目录中没有__init__.py文件来作为共同的命名空间

例如:foo-package and bar-package are two different dir path, but they all contain "spam" (without __init__.py)

3 np.random.RandomState() 和np.random.seed()区别

np.random.RandomState()可以构造一个随机数生成器,他对独立功能np.random.没有影响,不影响整体

In [44]: np.random.seed(20)

In [45]: np.random.uniform(0,10,5)
Out[45]: array([5.88130801, 8.97713728, 8.91530729, 8.15837477, 0.35889586])

In [46]: np.random.rand(2,3)
Out[46]: 
array([[0.69175758, 0.37868094, 0.51851095],
       [0.65795147, 0.19385022, 0.2723164 ]])

In [47]: r=np.random.RandomState(20)

In [48]: r.uniform(0,10,5)
Out[48]: array([5.88130801, 8.97713728, 8.91530729, 8.15837477, 0.35889586])

In [49]: r.rand(2,3)
Out[49]: 
array([[0.69175758, 0.37868094, 0.51851095],
       [0.65795147, 0.19385022, 0.2723164 ]])


In [52]: np.random.randn(4)
Out[52]: array([0.91635593, 0.70783847, 0.41967613, 0.53415759])

In [53]: r.randn(4)
Out[53]: array([0.91635593, 0.70783847, 0.41967613, 0.53415759])

4 print 和sys.stdout的关系

print默认调用sys.stdout输出当前输出面板,见print的源代码,默认使用file的write方法

def print(self, *args, sep=' ', end='\n', file=None): # known special case of print
    """
    print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)
    
    Prints the values to a stream, or to sys.stdout by default.
    Optional keyword arguments:
    file:  a file-like object (stream); defaults to the current sys.stdout.
    sep:   string inserted between values, default a space.
    end:   string appended after the last value, default a newline.
    flush: whether to forcibly flush the stream.
    """
    pass

若我们修改file就可以直接使用print输出到我们的文件中,例

In [8]: print('eric love kani',file=open('aa.txt',"a"))

In [9]: print('eric love kani 3000',file=open('aa.txt',"a"))

In [10]: print('eric love kani 3000')
eric love kani 3000

其中flush变量为True的时候会立刻输出,为False会攒一大波,一起输出,目前测试如此。但是pycharm有时候不按套路出牌。。。

5 我终于知道cval的全称了23333

"""
    cval: fill value for 'constant' padding mode. Default: 0
"""

未完待续......

猜你喜欢

转载自blog.csdn.net/Eric_Evil/article/details/107513786