M.L (p.161)
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import numpy as np class MulLayer: def __init__(self): self.x = None self.y = None def forward(self,x,y): self.x = x self.y = y out = x * y return out def backward(self,dout): dx = dout * self.y dy = dout * self.x return dx,dy class AddLayer: def __init__(self): pass def forward(self,x,y): out = x + y return out def backward(self,dout): dx = dout dy = dout return dx,dy |
backpropagation(역전파)에 덧셈과 곱셈 노드에 대한 코드이다.
역전파는 계산그래프를 거꾸로 돌아가며 각 weight와 parameter가 Y(output value)에 얼마나 영향을 미치는지(미분을 통해) chain rule 이라는 성질을 이용해 쉽게 구할 수 있다.
덧셈에 대한 역전파로 Z = x + y 를 각 변수에 맞춰 미분하면
1이 됨을 쉽게 알 수 있다. 결국 노드에서 흘러들어 오는 값을 그대로 가지게 되는 것이다.
다음은 곱셈에 대한 역전파로 간단히 Z = xy 식을 생각해보면 각 변수에 맞춰 편미분을 하면
값이 반대를 가지는 것을 알 수 있다. 즉 서로 바꾼 값이라 이해하면 된다.
이처럼 역전파를 하면 쉽게 식이 쓰이기 때문에 코드를 간편하게 만들 수 있다.