将自己的数据集制作成TFRecord格式教程-创新互联
在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格
式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。
1. 原本的数据集
此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。

2.制作成TFRecord格式
tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。
#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
#生成字符串类型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
#制作TFRecord格式
def createTFRecord(filename,mapfile):
class_map = {}
data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
classes = {'xiansu60','xiansu100'}
#输出TFRecord文件的地址
writer = tf.python_io.TFRecordWriter(filename)
for index,name in enumerate(classes):
class_path=data_dir+name+'/'
class_map[index] = name
for img_name in os.listdir(class_path):
img_path = class_path + img_name #每个图片的地址
img = Image.open(img_path)
img= img.resize((224,224))
img_raw = img.tobytes() #将图片转化成二进制格式
example = tf.train.Example(features = tf.train.Features(feature = {
'label':_int64_feature(index),
'image_raw': _bytes_feature(img_raw)
}))
writer.write(example.SerializeToString())
writer.close()
txtfile = open(mapfile,'w+')
for key in class_map.keys():
txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()
本文名称:将自己的数据集制作成TFRecord格式教程-创新互联
标题路径:http://www.jxjierui.cn/article/jggjd.html


咨询
建站咨询
