Tensorflowのpy_funcの使い方

Tensorflowでは基本的な計算や処理を行う関数はライブラリに用意されており、基本的には不自由なくグラフ定義ができるとは思います。しかし、NumpyやScipyの関数に対応するものがTensorflowになく、どうしても直接利用しなければならないという場面があります。そこで活躍するのが、tf.py_funcになります。

簡単に言うと、Pythonで定義された関数をtensorflowでそのまま使えるようにラップする役割を果たします。

tf.py_funcのAPIリファレンス

tf.py_func(
    func,
    inp,
    Tout,
    stateful=True,
    name=None
)
  • func: Pythonで定義した関数。テンソルに変換できるように、ndarrayもしくはそのリストを返り値としてもつ必要がある。inpと入力を対応させること。
  • inp: 入力として与えるTensorオブジェクトのリスト。
  • Tout: Tensorflowのデータ型のタプル。funcが返すデータ型を指定する。
  • stateful: ステートフルorステートレスを指定する。ステートレスの場合、同じ入力対しては全く同じ値を出力することになる。
  • name: 処理の名称を指定する。(オプション)

用法

下記にコーディング例を示します。Part1で定義したtensorグラフをinputsに入れます。inputを入力としてpythonで定義した関数funcを実行し、出力結果をtensorグラフとしてoutputに出力します。Part2では引き続きoutputの結果をもとにグラフを定義すれば、sess.runを一度実行するだけでPart1~Part2一連のグラフ処理を実行できます。

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need of a intermediate placeholder
sess.run(train_op)

Py_funcが利用できない場合、Part1の結果を一度sess.runで出力してからPythonの関数に入力として与える必要があります。さらに、その結果をPart2にわたす際はPlaceholderを定義して実行時に与えなければなりません。以下のコードのようにグラフを2つに分割した処理となります。

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

Py_funcを利用するメリット・デメリット

上記で述べたとおり、グラフの中でsess.runを呼び出さなくてもPythonの関数をグラフに組み込んで利用することができるというメリットがあります。

一方で、マルチスレッディングに対応させるのが困難であるため処理速度のボトルネックとなってしまう懸念があります。

シェアする

フォローする