ResNetの改良版であるResNextは、パラメータ数はほぼ同じなのに精度が向上している。
試したくなったが、TensorflowやKerasだとデフォルトではkeras.applicationsに入っていないので使えない。
githubのソースコードを辿っていくとResNeXtと思われる実装があったので、流用して使えるようにしてみた。
resnetのソースコードを確認。
File not found · keras-team/keras
Deep Learning for humans. Contribute to keras-team/keras development by creating an account on GitHub.
File not found · tensorflow/tensorflow
An Open Source Machine Learning Framework for Everyone - File not found · tensorflow/tensorflow
WEIGHTS_HASHES = {
'resnet50': ('2cb95161c43110f7111970584f804107',
'4d473c1dd8becc155b73f8504c6f6626'),
'resnet101': ('f1aeb4b969a6efcfb50fad2f0c20cfc5',
'88cf7a10940856eca736dc7b7e228a21'),
'resnet152': ('100835be76be38e30d865e96f2aaae62',
'ee4c566cf9a93f14d82f913c2dc6dd0c'),
'resnet50v2': ('3ef43a0b657b3be2300d5770ece849e0',
'fac2f116257151a9d068a22e544a4917'),
'resnet101v2': ('6343647c601c52e1368623803854d971',
'c0ed64b8031c3730f411d2eb4eea35b5'),
'resnet152v2': ('a49b44d1979771252814e80f8ec446f9',
'ed17cf2e0169df9d443503ef94b23b33'),
'resnext50': ('67a5b30d522ed92f75a1f16eef299d1a',
'62527c363bdd9ec598bed41947b379fc'),
'resnext101':
('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3')
}
imagenetの重みがある!
しかし、肝心のResNeXtが定義されていない。
keras-applicationsのリポジトリを見てみる。
keras-applications/keras_applications/resnet_common.py at master · keras-team/keras-applications
Reference implementations of popular deep learning models. - keras-team/keras-applications
# 543 行目
def ResNeXt50(...):
(中略)
# 563 行目
def ResNeXt101(...):
(中略)
# 最下段 589 行目
setattr(ResNeXt50, '__doc__', ResNet.__doc__)
setattr(ResNeXt101, '__doc__', ResNet.__doc__)
resnextの実装があった!
ResNeXt50, ResNeXt101, setattr()のコードをコピーして、自分のマシンのTensorflowやKerasのソースコードに貼り付ける。
例えばTensorflowの場合、自分はpyenvでpy384というvirtualenvを使っていたので、以下のソースコードを開く。
~/.pyenv/versions/3.8.4/envs/py384/lib/python3.8/site-packages/tensorflow/python/keras/applications/resnet.py
resnextのコードをコピペ。
他のresnetにならって、@keras_exportでデコレートし、setattr()も微修正。
@keras_export('keras.applications.resnet.ResNeXt50',
'keras.applications.ResNeXt50')
def ResNeXt50(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
**kwargs):
def stack_fn(x):
x = stack3(x, 128, 3, stride1=1, name='conv2')
x = stack3(x, 256, 4, name='conv3')
x = stack3(x, 512, 6, name='conv4')
x = stack3(x, 1024, 3, name='conv5')
return x
return ResNet(stack_fn, False, False, 'resnext50',
include_top, weights,
input_tensor, input_shape,
pooling, classes,
**kwargs)
@keras_export('keras.applications.resnet.ResNeXt101',
'keras.applications.ResNeXt101')
def ResNeXt101(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
pooling=None,
classes=1000,
**kwargs):
def stack_fn(x):
x = stack3(x, 128, 3, stride1=1, name='conv2')
x = stack3(x, 256, 4, name='conv3')
x = stack3(x, 512, 23, name='conv4')
x = stack3(x, 1024, 3, name='conv5')
return x
return ResNet(stack_fn, False, False, 'resnext101',
include_top, weights,
input_tensor, input_shape,
pooling, classes,
**kwargs)
setattr(ResNeXt50, '__doc__', ResNet.__doc__ + DOC)
setattr(ResNeXt101, '__doc__', ResNet.__doc__ + DOC)
site-packages/tensorflow/keras/applications/resnet/init.py にimport文を追記。
from tensorflow.python.keras.applications.resnet import ResNeXt50
from tensorflow.python.keras.applications.resnet import ResNeXt101
無事に動くかな?
from tensorflow.keras.applications.resnet import ResNeXt50
model = ResNeXt50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
model.summary()
Model: "resnext50"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_4[0][0]
(中略)
avg_pool (GlobalAveragePooling2 (None, 2048) 0 conv5_block3_out[0][0]
__________________________________________________________________________________________________
predictions (Dense) (None, 1000) 2049000 avg_pool[0][0]
==================================================================================================
Total params: 25,097,128
Trainable params: 25,028,904
Non-trainable params: 68,224
これでResNeXt50, 101がTensorflowで使えるようになった。
コメント