Home > AI > Uncategorized

TensorFlow – TfRecord2

This example uses the MNIST dataset and converts one training example (image, label, shape) to a TFRecord.

You will learn how to write and read a TFRecord file using three feature types:

tf.train.BytesList(value=[])

tf.train.Int64List(value=[])

tf.train.FloatList(value=[])

After completing this, think about how you would convert all of the training and test data.


# Write one entry of mnist.train to a TFRecord file, then read it back from
# that file and print the decoded image, label and shape.
#
# NOTE(review): uses TensorFlow 1.x APIs (tf.python_io, tf.Session,
# tf.parse_single_example, tensorflow.examples.tutorials) that no longer
# exist in TensorFlow 2.x.

import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('../../Model/Mnist_data', one_hot=True)

# prepare data: the first training image, its one-hot label and its shape
image = mnist.train.images[0]
label = mnist.train.labels[0]
shape = image.shape

prefix = './tmp/mnist_'
suffix = '.tfrecords'
split = 'train'  # named `split` rather than `type` to avoid shadowing the builtin

filename = prefix + split + suffix

# TFRecordWriter does not create missing parent directories.
out_dir = os.path.dirname(filename)
if out_dir and not os.path.isdir(out_dir):
    os.makedirs(out_dir)


# write tfrecords
features = {
    # tobytes() is the non-deprecated equivalent of ndarray.tostring().
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
    'label': tf.train.Feature(float_list=tf.train.FloatList(value=label)),
    'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=shape)),
}

example = tf.train.Example(features=tf.train.Features(feature=features))
# The `with` block closes the writer, which flushes the record to disk
# before the file is read back below.
with tf.python_io.TFRecordWriter(filename) as writer:
    writer.write(example.SerializeToString())


# read tfrecords
feature_stencil = {
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([len(label)], tf.float32),
    # One int64 per dimension, so images of any rank round-trip.
    'shape': tf.FixedLenFeature([len(shape)], tf.int64),
}

# Read the first serialized record from the file we just wrote, rather than
# re-serializing the in-memory example, so the file itself is exercised.
serialized = next(tf.python_io.tf_record_iterator(filename))
feats = tf.parse_single_example(serialized, feature_stencil)
# NOTE(review): decode_raw's dtype must match the dtype the image had when
# tobytes() was called; float32 is assumed here - confirm against input_data.
d_shape = feats['shape']
d_image = tf.reshape(tf.decode_raw(feats['image'], tf.float32), d_shape)
d_label = feats['label']

# The `with` block releases the session's resources when done.
with tf.Session() as se:
    print(se.run(d_image))
    print('\n\n')
    print(se.run(d_label))
    print('\n\n')
    print(se.run(d_shape))

 

Related posts:

Leave a Reply