stringのTFRecordを作る時にTypeError

猫の画像とラベルを含んだTFRecordを作ろうとして表題のエラー。

TypeError: 'cat' has type str, but expected one of: bytes

使ったコードは公式チュートリアルに記載のもの。

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

_bytes_feature('cat')

Returns a bytes_list from a string / byte と書いてあるが、strを与えるとbytesにしろと怒られてしまった。

チュートリアルのちょっと下にヒントを発見。

Below are some examples of how these functions work. 
Note the varying input types and the standardized output types. 
If the input type for a function does not match one of the coercible types stated above, the function will raise an exception (e.g. _int64_feature(1.0) will error out, since 1.0 is a float, so should be used with the _float_feature function instead):

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

utf-8に明示的にエンコードすれば良いようだ。

_bytes_feature('cat'.encode('utf-8'))

これでエラーは出なくなり、無事に画像とラベルを含んだTFRecordを作成することが出来た。

参考

TFRecord and tf.train.Example  |  TensorFlow Core

コメント

タイトルとURLをコピーしました