【教程】使用PAI深度学习tensorflow读取OSS教程
转载需注明出处
1. 如何PAI上读取数据
2. 如何减少读取的费用开支
3. 使用OSS需要注意的问题
1. 在PAI上读取数据
Python不支持读取oss的数据, 故所有调用 python Open(), os.path.exist() 等文件, 文件夹操作的函数的代码都无法执行.
如Scipy.misc.imread(), numpy.load() 等
如果只是简单的读取一张图片, 或者一个文本等, 可以使用tf.gfile下的函数, 具体成员函数如下
tf.gfile.Copy(oldpath, newpath, overwrite=False) # 拷贝文件 tf.gfile.DeleteRecursively(dirname) # 递归删除目录下所有文件 tf.gfile.Exists(filename) # 文件是否存在 tf.gfile.FastGFile(name, mode='r') # 无阻塞读写文件 tf.gfile.GFile(name, mode='r') # 读写文件 tf.gfile.Glob(filename) # 列出文件夹下所有文件, 支持pattern tf.gfile.IsDirectory(dirname) # 返回dirname是否为一个目录 tf.gfile.ListDirectory(dirname) # 列出dirname下所有文件 tf.gfile.MakeDirs(dirname) # 在dirname下创建一个文件夹, 如果父目录不存在, 会自动创建父目录. 如果 文件夹已经存在, 且文件夹可写, 会返回成功 tf.gfile.MkDir(dirname) # 在dirname处创建一个文件夹 tf.gfile.Remove(filename) # 删除filename tf.gfile.Rename(oldname, newname, overwrite=False) # 重命名 tf.gfile.Stat(dirname) # 返回目录的统计数据 tf.gfile.Walk(top, inOrder=True) # 返回目录的文件树
具体的文档可以参照这里(可能需要翻墙)
如果是一批一批的读取文件, 一般会采用tf.WhoFileReader() 和 tf.train.batch() / tf.train.shuffer_batch()
接下来会重点介绍常用的 tf.gfile.Glob, tf.gfile.FastGFile, tf.WhoFileReader() 和 tf.train.shuffer_batch()
1. 获取文件列表 2. 读取文件
3. 创建batch
tf.flags可以提供了这个功能
import tensorflow as tf import os FLAGS = tf.flags.FLAGS # 前面的buckets, checkpointDir都是固定的, 不建议更改 tf.flags.DEFINE_string('buckets', 'oss://XXX', '训练图片所在文件夹') tf.flags.DEFINE_string('batch_size', '15', 'batch大小') # 获取文件列表 files = tf.gfile.Glob(os.path.join(FLAGS.buckets,'*.jpg')) # 如我想列出buckets下所有jpg文件路径1. (小规模读取时建议) tf.gfile.FastGfile()
for path in files:
file_content = tf.gfile.FastGFile(path, 'rb').read() # 一定记得使用rb读取, 不然很多情况下都会报错 image = tf.image.decode_jpeg(file_content, channels=3) # 本教程以JPG图片为例
2. (大批量读取时建议) tf.WhoFileReader()
reader = tf.WholeFileReader() # 实例化一个reader fileQueue = tf.train.string_input_producer(files) # 创建一个供reader读取的队列 file_name, file_content = reader.read(fileQueue) # 使reader从队列中读取一个文件 image = tf.image.decode_jpeg(file_content, channels=3) # 讲读取结果解码为图片 label = XXX # 这里省略处理label的过程 batch = tf.train.shuffle_batch([label, image], batch_size=FLAGS.batch_size, num_threads=4, capacity=1000 + 3 * FLAGS.batch_size, min_after_dequeue=1000) sess = tf.Session() # 创建Session tf.train.start_queue_runners(sess=sess) # 重要!!! 这个函数是启动队列, 不加这句线程会一直阻塞 labels, images = sess.run(batch) # 获取结果解释下其中重要的部分
tf.train.string_input_producer, 这个是把files转换成一个队列, 并且需要 tf.train.start_queue_runners 来启动队列
tf.train.shuffle_batch 参数解释
batch_size 批大小, 每次运行这个batch, 返回多少个数据
num_threads 运行线程数, 在PAI上4个就好
capacity 随机取文件范围, 比如你的数据集有10000个数据, 你想从5000个数据中随机取, capacity就设置成5000.
min_after_dequeue 维持队列的最小长度, 这里只要注意不要大于capacity即可
原则上来说, PAI不跨区域读取OSS是不收费的, 但是OSS的API是收费的. PAI在使用 tf.gile.Glob 的时候 会产生GET请求, 在写入 tensorboard 的时候, 也会产生PUT请求. 这两种请求都是按次收费的, 具体价格如下
当数据集有几十万图片, 通过 tf.gile.Glob 一次就需要几毛钱. 所以减少费用开支的方法就是减少GET请求次数
1. 最好的解决思路, 把所有会使用到的数据, 一并打包传到OSS, 然后使用python解压, 最后通过tensorflow读取, 这样是最节省开支的. 缺点是灵活性不强, 不过代码和训练数据分开上传, 相比一起上传提高了灵活性
通过tfrecords, 在本地, 提前把几十上百张图片通过tfrecords存下来, 这样读取的时候可以减少GET请求
. 把训练使用的图片随着代码的压缩包一起传上去, 不走OSS读取
3.使用中需要注意的
PAI没有权限读取不在数据源目录和输出目录下的文件, 所以在使用路径前, 确保他们已经在控制台右侧设置过.
另外如果需要写入文件到OSS, 可以使用 tf.gfile.fastGfile('OSS路径', 'wb').write('内容')
最后更新:2017-08-23 16:02:30