TensorflowやKerasでResNeXt50, ResNeXt101を使う

ResNetの改良版であるResNextは、パラメータ数はほぼ同じなのに精度が向上している。

試したくなったが、TensorflowやKerasだとデフォルトではkeras.applicationsに入っていないので使えない。

githubのソースコードを辿っていくとResNeXtと思われる実装があったので、流用して使えるようにしてみた。

resnetのソースコードを確認。

File not found · keras-team/keras
Deep Learning for humans. Contribute to keras-team/keras development by creating an account on GitHub.
File not found · tensorflow/tensorflow
An Open Source Machine Learning Framework for Everyone - File not found · tensorflow/tensorflow

WEIGHTS_HASHES = {
    'resnet50': ('2cb95161c43110f7111970584f804107',
                 '4d473c1dd8becc155b73f8504c6f6626'),
    'resnet101': ('f1aeb4b969a6efcfb50fad2f0c20cfc5',
                  '88cf7a10940856eca736dc7b7e228a21'),
    'resnet152': ('100835be76be38e30d865e96f2aaae62',
                  'ee4c566cf9a93f14d82f913c2dc6dd0c'),
    'resnet50v2': ('3ef43a0b657b3be2300d5770ece849e0',
                   'fac2f116257151a9d068a22e544a4917'),
    'resnet101v2': ('6343647c601c52e1368623803854d971',
                    'c0ed64b8031c3730f411d2eb4eea35b5'),
    'resnet152v2': ('a49b44d1979771252814e80f8ec446f9',
                    'ed17cf2e0169df9d443503ef94b23b33'),
    'resnext50': ('67a5b30d522ed92f75a1f16eef299d1a',
                  '62527c363bdd9ec598bed41947b379fc'),
    'resnext101':
        ('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3')
}

imagenetの重みがある!

しかし、肝心のResNeXtが定義されていない。

keras-applicationsのリポジトリを見てみる。

keras-applications/keras_applications/resnet_common.py at master · keras-team/keras-applications
Reference implementations of popular deep learning models. - keras-team/keras-applications

# 543 行目
def ResNeXt50(...):
   (中略)

# 563 行目
def ResNeXt101(...):
   (中略)

# 最下段 589 行目
setattr(ResNeXt50, '__doc__', ResNet.__doc__)
setattr(ResNeXt101, '__doc__', ResNet.__doc__)

resnextの実装があった!

ResNeXt50, ResNeXt101, setattr()のコードをコピーして、自分のマシンのTensorflowやKerasのソースコードに貼り付ける。

例えばTensorflowの場合、自分はpyenvでpy384というvirtualenvを使っていたので、以下のソースコードを開く。

~/.pyenv/versions/3.8.4/envs/py384/lib/python3.8/site-packages/tensorflow/python/keras/applications/resnet.py

resnextのコードをコピペ。

他のresnetにならって、@keras_exportでデコレートし、setattr()も微修正。

@keras_export('keras.applications.resnet.ResNeXt50',
              'keras.applications.ResNeXt50')
def ResNeXt50(include_top=True,
              weights='imagenet',
              input_tensor=None,
              input_shape=None,
              pooling=None,
              classes=1000,
              **kwargs):
    def stack_fn(x):
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        x = stack3(x, 256, 4, name='conv3')
        x = stack3(x, 512, 6, name='conv4')
        x = stack3(x, 1024, 3, name='conv5')
        return x

    return ResNet(stack_fn, False, False, 'resnext50',
                  include_top, weights,
                  input_tensor, input_shape,
                  pooling, classes,
                  **kwargs)


@keras_export('keras.applications.resnet.ResNeXt101',
              'keras.applications.ResNeXt101')
def ResNeXt101(include_top=True,
               weights='imagenet',
               input_tensor=None,
               input_shape=None,
               pooling=None,
               classes=1000,
               **kwargs):
    def stack_fn(x):
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        x = stack3(x, 256, 4, name='conv3')
        x = stack3(x, 512, 23, name='conv4')
        x = stack3(x, 1024, 3, name='conv5')
        return x

    return ResNet(stack_fn, False, False, 'resnext101',
                  include_top, weights,
                  input_tensor, input_shape,
                  pooling, classes,
                  **kwargs)

setattr(ResNeXt50, '__doc__', ResNet.__doc__ + DOC)
setattr(ResNeXt101, '__doc__', ResNet.__doc__ + DOC)

site-packages/tensorflow/keras/applications/resnet/init.py にimport文を追記。

from tensorflow.python.keras.applications.resnet import ResNeXt50
from tensorflow.python.keras.applications.resnet import ResNeXt101

無事に動くかな?

from tensorflow.keras.applications.resnet import ResNeXt50

model = ResNeXt50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
model.summary()



Model: "resnext50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    

(中略)

avg_pool (GlobalAveragePooling2 (None, 2048)         0           conv5_block3_out[0][0]           
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         2049000     avg_pool[0][0]                   
==================================================================================================
Total params: 25,097,128
Trainable params: 25,028,904
Non-trainable params: 68,224

これでResNeXt50, 101がTensorflowで使えるようになった。

コメント

タイトルとURLをコピーしました