Deep learning/모델 구현

22. 역전파

jwjwvison 2021. 9. 29. 00:16

 머신러닝은 주로 대량의 매개변수를 입력받아서 마지막에 손실 함수를 거쳐 출력을 내는 형태로 진행된다. 즉, 손실 함수의 각 매개변수에 대한 미분을 계산해야한다. 이런 경우 미분값을 출력에서 입력 방향으로 전파하면 한 번의 전파만으로 모든 매개변수에 대한 미분을 계산할 수 있다. 이처럼 계산이 효율적으로 이뤄지기 때문에 미분을 반대 방향으로 전파하는 방식(역전파)를 이용하는 것이다.

 역전파를 구현할때 주의할 점은 순전파 시 이용한 데이터가 필요하다는 것이다. 따라서 역전파를 구현하려면 먼저 순전파를 하고, 이때 각 함수가 입력 변수의 값을 기억해두지 않으면 안된다. 그런 다음에야 각 함수의 역전파를 계산할 수 있다.

 

역전파에 대응하는 Variable 클래스를 구현

class Variable:
    def __init__(self,data):
        self.data=data
        self.grad=None

 

Function 클래스 추가 구현

class Function:
    def __call__(self,input):
        x=input.data
        y=self.forward(x)
        output=Variable(y)
        self.input=input
        return output
    
    def forward(self,x):
        raise NotImplementedError()
        
    def backward(self,gy):
        raise NotImplementedError()

 코드에서 보듯 __call__ 메서드에서 입력된 input을 인스턴스 변수인 self.input에 저장한다. 이렇게 해서 나중에 backward 메서드에서 함수에 입력한 변수가 필요할 때 self.input에서 가져와 사용할 수 있다.

 

 Function을 상속한 구체적인 함수에서 역전파 구현

class Square(Function):
    def forward(self,x):
        y=x**2
        return y
    
    def backward(self,gy):
        x=self.input.data
        gx=2*x*gy
        return gx
        
        
class Exp(Function):
    def forward(self,x):
        y=np.exp(x)
        return y
    
    def backward(self,gy):
        x=self.input.data
        gx=np.exp(x) * gy
        return gx

 

 이제 준비작업이 끝났다. 다음 그림에 해당하는 계산의 미분을 역전파로 계산해보겠다.

 

A=Square()
B=Exp()
C=Square()

x=Variable(np.array(0.5))
a=A(x)
b=B(a)
y=C(b)

y.grad=np.array(1.0)
b.grad=C.backward(y.grad)
a.grad=B.backward(b.grad)
x.grad=A.backward(a.grad)
print(x.grad)

 

 이상이 역전파 구현이다. 제대로 동작하지만 역전파 순서에 맞춰 호출하는 코드를 우리가 일일이 작성해 넣는 건 불편할 것 같다. 이 작업을 자동화 해보자.

'Deep learning > 모델 구현' 카테고리의 다른 글

24. 가변 길이 인수(순전파)  (0) 2021.10.04
23. 역전파 자동화  (0) 2021.09.29
21. 수치미분  (0) 2021.09.28
20. 변수와 함수  (0) 2021.09.28
19. CNN(6) - CNN 시각화하기, 대표적인 CNN  (0) 2021.04.26