スポンサーリンク

TensorflowやKerasでResNeXt50, ResNeXt101を使う

ResNetの改良版であるResNextは、パラメータ数はほぼ同じなのに精度が向上している。

試したくなったが、TensorflowやKerasだとデフォルトではkeras.applicationsに入っていないので使えない。

githubのソースコードを辿っていくとResNeXtと思われる実装があったので、流用して使えるようにしてみた。

resnetのソースコードを確認。

keras/resnet.py at master · keras-team/keras
Deep Learning for humans. Contribute to keras-team/keras development by creating an account on GitHub.
Build software better, together
GitHub is where people build software. More than 94 million people use GitHub to discover, fork, and contribute to over 330 million projects.

WEIGHTS_HASHES = {
    'resnet50': ('2cb95161c43110f7111970584f804107',
                 '4d473c1dd8becc155b73f8504c6f6626'),
    'resnet101': ('f1aeb4b969a6efcfb50fad2f0c20cfc5',
                  '88cf7a10940856eca736dc7b7e228a21'),
    'resnet152': ('100835be76be38e30d865e96f2aaae62',
                  'ee4c566cf9a93f14d82f913c2dc6dd0c'),
    'resnet50v2': ('3ef43a0b657b3be2300d5770ece849e0',
                   'fac2f116257151a9d068a22e544a4917'),
    'resnet101v2': ('6343647c601c52e1368623803854d971',
                    'c0ed64b8031c3730f411d2eb4eea35b5'),
    'resnet152v2': ('a49b44d1979771252814e80f8ec446f9',
                    'ed17cf2e0169df9d443503ef94b23b33'),
    'resnext50': ('67a5b30d522ed92f75a1f16eef299d1a',
                  '62527c363bdd9ec598bed41947b379fc'),
    'resnext101':
        ('34fb605428fcc7aa4d62f44404c11509', '0f678c91647380debd923963594981b3')
}

imagenetの重みがある!

しかし、肝心のResNeXtが定義されていない。

keras-applicationsのリポジトリを見てみる。

keras-applications/resnet_common.py at master · keras-team/keras-applications
Reference implementations of popular deep learning models. - keras-applications/resnet_common.py at master · keras-team/keras-applications

# 543 行目
def ResNeXt50(...):
   (中略)

# 563 行目
def ResNeXt101(...):
   (中略)

# 最下段 589 行目
setattr(ResNeXt50, '__doc__', ResNet.__doc__)
setattr(ResNeXt101, '__doc__', ResNet.__doc__)

resnextの実装があった!

ResNeXt50, ResNeXt101, setattr()のコードをコピーして、自分のマシンのTensorflowやKerasのソースコードに貼り付ける。

例えばTensorflowの場合、自分はpyenvでpy384というvirtualenvを使っていたので、以下のソースコードを開く。

~/.pyenv/versions/3.8.4/envs/py384/lib/python3.8/site-packages/tensorflow/python/keras/applications/resnet.py

resnextのコードをコピペ。

他のresnetにならって、@keras_exportでデコレートし、setattr()も微修正。

@keras_export('keras.applications.resnet.ResNeXt50',
              'keras.applications.ResNeXt50')
def ResNeXt50(include_top=True,
              weights='imagenet',
              input_tensor=None,
              input_shape=None,
              pooling=None,
              classes=1000,
              **kwargs):
    def stack_fn(x):
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        x = stack3(x, 256, 4, name='conv3')
        x = stack3(x, 512, 6, name='conv4')
        x = stack3(x, 1024, 3, name='conv5')
        return x

    return ResNet(stack_fn, False, False, 'resnext50',
                  include_top, weights,
                  input_tensor, input_shape,
                  pooling, classes,
                  **kwargs)


@keras_export('keras.applications.resnet.ResNeXt101',
              'keras.applications.ResNeXt101')
def ResNeXt101(include_top=True,
               weights='imagenet',
               input_tensor=None,
               input_shape=None,
               pooling=None,
               classes=1000,
               **kwargs):
    def stack_fn(x):
        x = stack3(x, 128, 3, stride1=1, name='conv2')
        x = stack3(x, 256, 4, name='conv3')
        x = stack3(x, 512, 23, name='conv4')
        x = stack3(x, 1024, 3, name='conv5')
        return x

    return ResNet(stack_fn, False, False, 'resnext101',
                  include_top, weights,
                  input_tensor, input_shape,
                  pooling, classes,
                  **kwargs)

setattr(ResNeXt50, '__doc__', ResNet.__doc__ + DOC)
setattr(ResNeXt101, '__doc__', ResNet.__doc__ + DOC)

site-packages/tensorflow/keras/applications/resnet/init.py にimport文を追記。

from tensorflow.python.keras.applications.resnet import ResNeXt50
from tensorflow.python.keras.applications.resnet import ResNeXt101

無事に動くかな?

from tensorflow.keras.applications.resnet import ResNeXt50

model = ResNeXt50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
model.summary()



Model: "resnext50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    

(中略)

avg_pool (GlobalAveragePooling2 (None, 2048)         0           conv5_block3_out[0][0]           
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         2049000     avg_pool[0][0]                   
==================================================================================================
Total params: 25,097,128
Trainable params: 25,028,904
Non-trainable params: 68,224

これでResNeXt50, 101がTensorflowで使えるようになった。

コメント

タイトルとURLをコピーしました