머신러닝은 주로 대량의 매개변수를 입력받아서 마지막에 손실 함수를 거쳐 출력을 내는 형태로 진행된다. 즉, 손실 함수의 각 매개변수에 대한 미분을 계산해야한다. 이런 경우 미분값을 출력에서 입력 방향으로 전파하면 한 번의 전파만으로 모든 매개변수에 대한 미분을 계산할 수 있다. 이처럼 계산이 효율적으로 이뤄지기 때문에 미분을 반대 방향으로 전파하는 방식(역전파)를 이용하는 것이다.
역전파를 구현할때 주의할 점은 순전파 시 이용한 데이터가 필요하다는 것이다. 따라서 역전파를 구현하려면 먼저 순전파를 하고, 이때 각 함수가 입력 변수의 값을 기억해두지 않으면 안된다. 그런 다음에야 각 함수의 역전파를 계산할 수 있다.
역전파에 대응하는 Variable 클래스를 구현
class Variable:
def __init__(self,data):
self.data=data
self.grad=None
Function 클래스 추가 구현
class Function:
def __call__(self,input):
x=input.data
y=self.forward(x)
output=Variable(y)
self.input=input
return output
def forward(self,x):
raise NotImplementedError()
def backward(self,gy):
raise NotImplementedError()
코드에서 보듯 __call__ 메서드에서 입력된 input을 인스턴스 변수인 self.input에 저장한다. 이렇게 해서 나중에 backward 메서드에서 함수에 입력한 변수가 필요할 때 self.input에서 가져와 사용할 수 있다.
Function을 상속한 구체적인 함수에서 역전파 구현
class Square(Function):
def forward(self,x):
y=x**2
return y
def backward(self,gy):
x=self.input.data
gx=2*x*gy
return gx
class Exp(Function):
def forward(self,x):
y=np.exp(x)
return y
def backward(self,gy):
x=self.input.data
gx=np.exp(x) * gy
return gx
이제 준비작업이 끝났다. 다음 그림에 해당하는 계산의 미분을 역전파로 계산해보겠다.
A=Square()
B=Exp()
C=Square()
x=Variable(np.array(0.5))
a=A(x)
b=B(a)
y=C(b)
y.grad=np.array(1.0)
b.grad=C.backward(y.grad)
a.grad=B.backward(b.grad)
x.grad=A.backward(a.grad)
print(x.grad)
이상이 역전파 구현이다. 제대로 동작하지만 역전파 순서에 맞춰 호출하는 코드를 우리가 일일이 작성해 넣는 건 불편할 것 같다. 이 작업을 자동화 해보자.
'Deep learning > 모델 구현' 카테고리의 다른 글
24. 가변 길이 인수(순전파) (0) | 2021.10.04 |
---|---|
23. 역전파 자동화 (0) | 2021.09.29 |
21. 수치미분 (0) | 2021.09.28 |
20. 변수와 함수 (0) | 2021.09.28 |
19. CNN(6) - CNN 시각화하기, 대표적인 CNN (0) | 2021.04.26 |