在众多深度学习开源库的代码中经常出现Registry代码块,例如OpenMMlabfacebookresearchBasicSR中都使用了注册器机制。这块的代码经常会让新使用这些库的初学者感到一头雾水,本篇博客来分析一下注册器机制的原理与好处。


1. 为什么使用registry

在讲解registry原理前,我们先介绍一下,为何使用registry。registry的中文翻译是注册器。对于一个好用的深度学习代码库来说,通常都会内置多种损失函数,多种网络结构,以及多种优化器等。同时这类的库一般都支持从配置文件中,直接解析出模型结构与训练策略。那么如何优雅的从配置文件解析到具体的代码实现呢?这就是引入注册操作的意义,简而言之,注册器是为了方便找到相关模块。

2. registry代码阅读

在实现上不同代码库略有差异,但原理相同,所以这里就以BasicSR为例。

class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
""" def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {} def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj def register(self, obj=None, suffix=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class, suffix)
return func_or_class return deco # used as a function call
name = obj.__name__
self._do_register(name, obj, suffix) def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret def __contains__(self, name):
return name in self._obj_map def __iter__(self):
return iter(self._obj_map.items()) def keys(self):
return self._obj_map.keys() DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

上面的代码为数据集,架构,网络,损失以及度量方式都创建了一个注册器对象。核心代码在register函数里,register函数使用了装饰器的设计,也就是只要在功能模块前进行@xx.register()进行装饰,就会对原有功能模块进行注册,并且最终返回原始的功能模块,不修改其原有功能。

在更下层的_do_register()可以看到,这里使用的是一个字典来执行注册操作,记录的键值对分别是模块的名称以及模块本身。这样一来,读取配置文件中的模块字符串后,我们就能够直接通过函数名或者类名找到其具体实现。

@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
""" def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') self.loss_weight = loss_weight
self.reduction = reduction def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)

最新文章

  1. Python开发【第五篇】:Python基础之杂货铺
  2. java判定字符串中仅有数字和- 正则表达式匹配 *** 最爱那水货
  3. loadrunner简单的例子(demo)
  4. 另类的package-info.java文件探讨
  5. IOS 添加到通讯录
  6. 蜗牛爱课 - iOS7、8模态半透明弹出框
  7. linq中的cast<T>()及OfType<T>()
  8. CSS3学习系列之背景相关样式(二)
  9. 行为驱动:Cucumber + Selenium + Java(一) - 环境搭建
  10. python基础16_闭包_装饰器
  11. IP通信基础课堂笔记----以太网VLAN
  12. Ubuntu下安装pytorch(GPU版)
  13. 异常: Call From * 9000 failed on connection exception: java.net.ConnectException: Connection refused: no further information; For more details see: http://wiki.apache.org/hadoop/ConnectionRefused
  14. 【mongoDB高级篇①】聚集运算之group与aggregate
  15. 成员变量位置获取url
  16. 【IDEA&&Eclipse】2、从Eclipse转移到IntelliJ IDEA一点心得
  17. hive-site.xml
  18. Mysql 数据库 创建与删除(基础2)
  19. 如何在windows下安装JDK
  20. Bootstrap+Angularjs自制弹框

热门文章

  1. 【在下版本,有何贵干?】Dockerfile中 RUN yum -y install vim失败Cannot prepare internal mirrorlist: No URLs in mirrorlist
  2. jquery 动态 给select赋值
  3. JS 中 对象 基础认识
  4. MySQL实时在线备份恢复方案
  5. 浅谈 UNIX、Linux、ios、android 他们之间的关系
  6. Java 14中对switch的增强,终于可以不写break了
  7. Oracle 19c单实例部署
  8. CSAPP 之 AttackLab 详解
  9. 一文学会Java的交互式编程环境jshell
  10. Random 中的Seed