Tensorflowで学習するデータファイルの読み込みを行う方法

通常のPythonのコードでもファイルを読み込むことはできますが、DeepLearningにおいて入力〜学習までの一連の処理をTensorflowのグラフで定義することで、GPU上で並列処理を行うことができます。

全部のデータを一通り読み込んでから学習を開始するのではなく、データの読み込み、オーグメンテーション、学習を同時並行で進めたほうが効率的にリソースを活かすことができますよね。

本記事では、大規模な学習をする際に活躍するtf.train.string_input_producerを使った方法を紹介します。tf.train.string_input_producerはファイル入力のキューを作成することができ、Tensorflowで定義した学習ネットワークの入力に指定することで自動的に流し込んでくれます。

tf.train.string_input_producerのAPIリファレンス

入力として与えたString一覧をキューの文字列として出力する。

tf.train.string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)

string_tensor: 1次元のString型のテンソル。学習するファイルの名称の一覧を入力に与える。

①入力ファイルの一覧をcsvで読み込む方法

入力として与えるファイルの名称一覧を事前にcsvなどに書き出してある場合、次のように定義すると良いでしょう。

path_queue = tf.train.string_input_producer([csv_filename])
reader = tf.TextLineReader()
key, value = reader.read(path_queue)
in1, in2, output = tf.decode_csv(value, record_defaults=[[0.], [0.], [0.]])
…

②ディレクトリに存在するファイル一覧を取得する方法

対象のディレクトリに存在するファイルを入力として使用したい場合は、次のように記述します。

…
input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
…
path_queue = tf.train.string_input_producer(input_paths)
reader = tf.WholeFileReader()
paths, value = reader.read(path_queue)
raw_input = decode(value)

シェアする

フォローする