Deep learning/모델 구현

33. 형상 변환 함수, 합계 함수

jwjwvison 2021. 10. 16. 22:17
  • reshape 함수 구현

class Reshape(Function):
    def __init__ (self,shape):
        self.shape=shape

    def forward(self,x):
        self.x_shape=x.shape
        y=x.reshape(self.shape)
        return y
    
    def backward(self,gy):
        return reshape(gy,self.x_shape)

def reshape(x,shape):
    if x.shape==shape:
        return as_variable(x)

    return Reshape(shape)(x)

 여기서 reshape(x,shape) 의 x는 ndarray 인스턴스 또는 Variable 인스턴스 중 하나라고 가정한다.

 

  • 행렬의 전치

class Transpose(Function):
    def forward(self,x):
        y=np.transpose(x)
        return y

    def backward(self,gy):
        gx=transpose(gy)
        return gx
    
def transpose(x):
    return Transpose()(x)

 두 번째 T에는 @property 데코레이터가 붙어 있는데, '인스턴스 변수'로 사용할 수 있게 해주는 데코레이터이다.

 

  • sum 함수의 역전파

  • sum 함수 구현

DeZero의 sum 함수 역전파에서는 입력 변수의 형상과 같아지도록 기울기의 원소를 복사한다. 그런데 역전파에서는 Variable 인스턴스를 사용하므로 복사 작업도 DeZero 함수로 해야한다.

class Sum(Function):
    def __init__(self,axis,keepdims):
        self.axis=axis
        self.keepdims=keepdims

    def forward(self,x):
        self.x_shape=x.shape
        y=x.sum(axis=self.axis,keepdims=self.keepdims)
        return y
    
    def backward(self,gy):
        gy=utils.reshape_sum_backward(gy,self.x_shape,self.axis,self.keepdims)
        gx=broadcast_to(gy,self.x_shape)
        return gx

    
def sum(x,axis=None,keepdims=False):
    return Sum(axis,keepdims)(x)

 

  • 브로드캐스트 함수
    class BroadcastTo(Function):
        def __init__(self,shape):
            self.shape=shape
    
        def forward(self,x):
            self.x_shape=x.shape
            y=np.broadcast_to(x,self.shape)
            return y
    
        def backward(self,gy):
            gx=sum_to(gy,self.x_shape)
            return gx
    
    def broadcast_to(x,shape):
        if x.shape==shape:
            return as_variable(x)
        return BroadcastTo(shape)(x)

 

'Deep learning > 모델 구현' 카테고리의 다른 글

35. 매개변수를 모아두는 계층  (0) 2021.10.17
34. 신경망  (0) 2021.10.17
32. 고차 미분  (0) 2021.10.14
31. 함수 최적화(중요)  (0) 2021.10.12
30. 테일러 급수 미분  (0) 2021.10.12