Deep learning/모델 구현

32. 고차 미분

jwjwvison 2021. 10. 14. 21:28

 현재의 DeZero는 미분을 자동으로 계산할 수 있지만 1차 미분 한정이다. 그래서 이번 단계에서는 2차 미분도 자동으로 계산할 수 있도록, 나아가 3차 미분, 4차 미분... 형태의 모든 고차 미분까지 자동으로 계산할 수 있도록 DeZero를 확장할 것이다.

 

  • 역전파 계산

 순전파와 마찬가지로 역전파에도 구체적인 계산 로직이 있다. 하지만 현재의 DeZero는 계산과 관련한 아무런 계산 그래프도 만들지 않는다. 왜냐하면 이 계산에서는 ndarray 인스턴스가 사용되기 때문이다. 만약 역전파를 계산할 때도 '연결'이 만들어진다면 고차 미분을 자동으로 계산할 수 있게 된다.

 위 그림은 sin함수의 미분을 구하기 위한 계산 그래프이다. 만약 위와 같은 계산 그래프가 있다면 gx.backward()를 호출하여 gx의 x에 대한 미분을 계산할 수 있다. 원래 gx는 y=sin(x)의 미분이기 때문에 gx.backward()를 호출함으로써 x에 대한 미분이 한 번 더 이루어진다. 즉 이것이 2차 미분에 해당한다.

 

  • 역전파로 계산 그래프 만들기

 함수의 backward 메서드에서 ndarray 인스턴스가 아닌 Variable 인스턴스를 사용하면 계산의 '연결'을 만들 수 있다.

 이렇게 변경하면 y=sin(x)를 다음 그림과 같은 계산 그래프로 표현할 수 있다.

 위 그림은 y.backward()를 호출함으로써 '새로' 만들어지는 계산 그래프이다. 위와 같은 계산 그래프가 있다면 gx.backward()를 호출함으로써 y의 x에 대한 2차 미분이 이루어진다.

 

  • Variable 클래스 수정
        def backward(self,retain_grad=False,create_graph=False):
            if self.grad is None:
                self.grad = Variable(np.ones_like(self.data))​

 self.grad가 Variable 인스턴스를 담게 된다.

 

  • 함수들의 backward 메서드 수정

 수정 후에는 Mul 클래스에서 Variable 인스턴스를 그대로 사용한다. 여기서 중요한 점은 역전파를 계산하는 gy * x1 코드이다. gy와 x1은 Variable 인스턴스이기 때문에 gy * x1 이 실행되는 뒤편에선는 Mul 클래스의 순전파가 호출된다. 그때 Function.__call__이 호출되고, 그 안에서 계산 그래프가 만들어진다.

 

  • 2차 미분 계산하기

y=x^4 - 2x^2 이라는 수식의 2차 미분을 계산해보자.

def f(x):
    y=x**4 - 2 * x ** 2
    return y


x=Variable(np.array(2.0))
y=f(x)
y.backward(create_graph=True)
print(x.grad)

gx=x.grad
x.cleargrad()
gx.backward()
print(x.grad)

 

'Deep learning > 모델 구현' 카테고리의 다른 글

34. 신경망  (0) 2021.10.17
33. 형상 변환 함수, 합계 함수  (0) 2021.10.16
31. 함수 최적화(중요)  (0) 2021.10.12
30. 테일러 급수 미분  (0) 2021.10.12
29. 연산자 오버로드  (0) 2021.10.10