tf.GradientTape API
- An API for automatic differentiation
- "Records" every operation executed inside its context onto a tape
- Uses reverse-mode differentiation to compute the gradients of the operations "recorded" on the tape
Example
import tensorflow as tf

# y = 10x
x = tf.constant(2.0)
with tf.GradientTape() as t:
    t.watch(x)
    y = tf.multiply(x, 10)

# derivative of y with respect to x
grad_dy_dx = t.gradient(y, x)
print(grad_dy_dx)
tf.Tensor(10.0, shape=(), dtype=float32)
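As a side note (a minimal sketch, not from the original example): watch() is only needed for plain tensors such as tf.constant. Trainable tf.Variable objects are watched by the tape automatically.

import tensorflow as tf

# A trainable tf.Variable is watched automatically,
# so no explicit watch() call is needed here.
x = tf.Variable(2.0)
with tf.GradientTape() as t:
    y = tf.multiply(x, 10)  # y = 10x

print(t.gradient(y, x))  # expected: tf.Tensor(10.0, shape=(), dtype=float32)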
Computing multiple gradients
x = tf.constant(2.0)
y = tf.constant(3.0)
with tf.GradientTape() as t:
    t.watch(x)
    t.watch(y)
    z = tf.multiply(x, y)

grad_dz_dx = t.gradient(z, x)
grad_dz_dy = t.gradient(z, y)  # raises RuntimeError (see below)
print(grad_dz_dx)
print(grad_dz_dy)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-a345c641a57c> in <module>()
7
8 grad_dz_dx = t.gradient(z, x)
----> 9 grad_dz_dy = t.gradient(z, y)
10 print(grad_dz_dx)
11 print(grad_dz_dy)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/backprop.py in gradient(self, target, sources, output_gradients, unconnected_gradients)
978 """
979 if self._tape is None:
--> 980 raise RuntimeError("GradientTape.gradient can only be called once on "
981 "non-persistent tapes.")
982 if self._recording:
RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.
The error above occurs because gradient() can only be called once on a non-persistent tape.
There are two ways to work around this.
1. Use two gradient tapes
x = tf.constant(2.0)
y = tf.constant(3.0)
with tf.GradientTape() as t1, tf.GradientTape() as t2:
    t1.watch(x)
    t2.watch(y)
    z = tf.multiply(x, y)

grad_dz_dx = t1.gradient(z, x)
grad_dz_dy = t2.gradient(z, y)
print(grad_dz_dx)
print(grad_dz_dy)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
2. Use the persistent option
x = tf.constant(2.0)
y = tf.constant(3.0)
with tf.GradientTape(persistent=True) as t:
    t.watch(x)
    t.watch(y)
    z = tf.multiply(x, y)

grad_dz_dx = t.gradient(z, x)
grad_dz_dy = t.gradient(z, y)
print(grad_dz_dx)
print(grad_dz_dy)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)
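One caveat with persistent=True: the tape holds on to its resources until it is garbage collected, so the TensorFlow autodiff tutorial recommends dropping the reference (del t) once you are done with it. The same tutorial also nests tapes to compute higher-order derivatives; here is a minimal sketch of that pattern (assuming TF 2.x eager execution):

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as t2:
    t2.watch(x)
    with tf.GradientTape() as t1:
        t1.watch(x)
        y = x * x * x              # y = x^3
    # The outer tape records the gradient computation itself,
    # so we can differentiate dy_dx again.
    dy_dx = t1.gradient(y, x)      # 3x^2 -> 27.0
d2y_dx2 = t2.gradient(dy_dx, x)    # 6x  -> 18.0
print(dy_dx, d2y_dx2)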
Reference 1: https://www.tensorflow.org/tutorials/customization/autodiff?hl=ko
Reference 2: https://datascienceschool.net/view-notebook/4b286ba9c76c4b36a9218074c8dce524/