- Mul 클래스 구현
Mul 클래스의 코드는 다음과 같다.
class Mul(Function):
def forward(self,x0,x1):
y=x0 * x1
return y
def backward(self,gy):
x0,x1=self.inputs[0].data,self.inputs[1].data
return gy * x1, gy * x0
def mul(x0,x1):
return Mul()(x0,x1)
하지만 이렇게 사용하면 상당히 불편할 것 같다. 연산자 * 를 사용할 수 있게끔 바꿔보겠다.
- 연산자 오버로드
def __mul__(self,other):
return mul(self,other)
Variable 클래스에 이 __mul__ 메서드를 추가한다. 이제부터 *를 사용하면 __mul__ 메서드가 다시 불리고, 다시 그 안의 함수 mul 함수가 불리게 된다.
이와 똑같은 작업을 다음 코드처럼 간단히 처리하는 방법도 있다.
Variable.__mul__=mul
Variable.__add__=add
- ndarray와 함께 사용하기
def as_variable(obj): if isinstance(obj,Variable): return obj return Variable(obj)
이 함수는 인수 obj가 Variable 인스턴스 또는 ndarray 인스턴스라고 가정한다.
그럼 Function 클래스의 __call__ 메서드가 as_variable 함수를 이용하도록 다음 부분을 수정한다.
class Function:
def __call__(self, *inputs):
inputs=[as_variable(x) for x in inputs] #Variable 인스턴스로 변환
- float, int 와 함께 사용하기
class Add(Function): def forward(self,x0,x1): y=x0+x1 return y def backward(self,gy): return gy,gy def add(x0,x1): x1=as_array(x1) return Add()(x0,x1)
Add 함수에 as_array를 추가해 준다. as_array 함수는 x1이 float이나 int인 경우 ndarray 인스턴스로 변환된다.
그리고 다음과 같은 코드를 통해 문제점들을 해결한다.
Variable.__add__=add
Variable.__radd__=add
Variable.__mul__=mul
Variable.__rmul__=mul
__array_priority__=200
- 또다른 연산자들
class Neg(Function): def forward(self,x): return -x def backward(self,gy): return -gy def neg(x): return Neg()(x) class Sub(Function): def forward(self,x0,x1): y=x0-x1 return y def backward(self,gy): return gy,-gy def sub(x0,x1): x1=as_array(x1) return Sub()(x0,x1) def rsub(x0,x1): x1=as_array(x1) return Sub()(x1,x0) class Div(Function): def forward(self,x0,x1): y=x0/x1 return y def backward(self,gy): x0,x1=self.inputs gx0=gy/x1 gx1=gy * (-x0 / x1 ** 2) return gx0,gx1 def div(x0,x1): x1=as_array(x1) return Div()(x0,x1) def rdiv(x0,x1): x1=as_array(x1) return Div()(x1,x0) class Pow(Function): def __init__(self,c): self.c=c def forward(self,x): y=x ** self.c return y def backward(self,gy): x,=self.inputs c=self.c gx=c * x ** (c-1) * gy return gx def pow(x,c): return Pow(c)(x)
'Deep learning > 모델 구현' 카테고리의 다른 글
31. 함수 최적화(중요) (0) | 2021.10.12 |
---|---|
30. 테일러 급수 미분 (0) | 2021.10.12 |
28. 변수 사용성 개선 (0) | 2021.10.10 |
27. 메모리 관리 (0) | 2021.10.04 |
26. 복잡한 계산 그래프 (0) | 2021.10.04 |