Python中的注册器模块
问题:
在使用BasicSR
的时候遇到了动态加载模型的方法,这个是一个很使用的方式,因为在实验过程中,我们不可避免的去写很多模型类,每一次都需要修改build_model
代码中的import **model as Network
这会给代码维护以及修改带来很大的困难。
如果我们只用在外部维护对应实验的.yml
文件该文件中包含了模型类的申明,那么每一个实验对应不同的.yml
文件,代码内部import
的流程我们则不需要去关心以及修改了。
如何解决:
关于我之前一直忽略的__init_
文件
我在这之前从未关心过每个文件夹下面的__init__
文件的用法,这个是每次注册类、import类的最开始执行的文件,具体来说但凡文件入口(train.py
)运行到具体的from file_name import *
指令的时候,就会执行该文件下的__init__
1 | from basicsr.data import build_dataloader, build_dataset |
在basicsr/archs/__init__
中会import
所有的 arch.py
文件:
1 | arch_folder = osp.dirname(osp.abspath(__file__)) |
是怎么import
各个模型是需要注意的,具体的,采用了修饰器进行导入。猜测 import_module
会调用每一个文件中静态注册函数:@ARCH_REGISTRY.register()
, 并且进行import
1 | from basicsr.utils.registry import ARCH_REGISTRY |
修饰器的用法等于:
1 |
|
在这个Registry
类中保留了所有注册的类以及其类名:明确来说:由
1 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] |
这条指令得到的在basicsr/archs/
文件下由 _arch.py
结尾的文件吗名称都会作为self._obj_map[name]
待所有类都被导入后,使用net = ARCH_REGISTRY.get(network_type)(**opt)
生成具体类实例。
怎么使用:
如果 model 不变则可以只增加_arch.py
注意命名必须要以 _arch.py
结尾。yml 文件中
1 | network_g: |
改为 arch 的名称即可。