- 가변 길이 인수에 대응한 Add 클래스의 역전파
덧셈의 역전파는 출력 쪽에서 전해지는 미분값에 1을 곱한 값이 입력 변수(x0,x1)의 미분이다. 즉, 상류에서 흘러오는 미분값을 그대로 흘려보내는 것이 덧셈의 역전파이다.
class Add(Function):
def forward(self,x0,x1):
y=x0+x1
return y
def backward(self,gy):
return gy,gy
- Variable 클래스 수정
class Variable: def __init__(self,data): if data is not None: if not isinstance(data,np.ndarray): raise TypeError('{}는 지원하지 않습니다'.format(type(data))) self.data=data self.grad=None self.creator=None def set_creator(self,func): self.creator=func def backward(self): if self.grad is None: self.grad = np.ones_like(self.data) funcs=[self.creator] while funcs: f=funcs.pop() gys=[output.grad for output in f.outputs] gxs=f.backward(*gys) if not isinstance(gxs,tuple): gxs=(gxs,) for x,gx in zip(f.inputs,gxs): x.grad=gx if x.creator is not None: funcs.append(x.creator)
맨 밑에 for문에서는 역전파로 전파되는 미분값을 Variable 인스턴스 변수 grad에 저장해둔다. 여기세어 gxs와 f.inputs의 각 원소는 서로 대응 관계에 있다. 더 정확히 말하면 i번째 원소에 대해 f.inputs[i]의 미분값은 gxs[i]에 대응한다.
- Square 클래스 구현
class Square(Function): def forward(self,x): y=x**2 return y def backward(self,gy): x=self.inputs[0].data gx=2*x*gy return gx
Function 클래스의 인스턴스 변수 이름이 단수형인 input에서 복수형인 inputs로 변경되었으니 바뀐 변수에서 입력 변수 x를 가져오도록 코드를 수정해주면 된다.
다음은 z=x*2 + y*2을 계산하는 코드이다.
x=Variable(np.array(2.0))
y=Variable(np.array(3.0))
z=add(square(x),square(y))
z.backward()
print(z.data)
print(x.grad)
print(y.grad)
- 같은 변수 반복 사용
현재의 DeZero 에는 문제가 있다. 같은 변수를 반복해서 사용할 경우 의도대로 동작하지 않을 수 있다는 문제이다.
제대로 계산한마면 미분값은 2가 나와야 한다(y=2x). 이를 해결하기 위해 다음과 같이 Variable클래스를 수정해주어야 한다.
class Variable:
def __init__(self,data):
if data is not None:
if not isinstance(data,np.ndarray):
raise TypeError('{}는 지원하지 않습니다'.format(type(data)))
self.data=data
self.grad=None
self.creator=None
def set_creator(self,func):
self.creator=func
def backward(self):
if self.grad is None:
self.grad = np.ones_like(self.data)
funcs=[self.creator]
while funcs:
f=funcs.pop()
gys=[output.grad for output in f.outputs]
gxs=f.backward(*gys)
if not isinstance(gxs,tuple):
gxs=(gxs,)
for x,gx in zip(f.inputs,gxs):
if x.grad is None:
x.grad=gx
else:
x.grad=x.grad + gx
if x.creator is not None:
funcs.append(x.creator)
이와 같이 미분값(grad)을 처음 설정하는 경우에는 지금까지와 똑같이 출력 쪽에서 전해지는 미분값을 그대로 대입한다. 그리고 다음번부터는 전달된 미분값을 더해주도록 수정한다.
x=Variable(np.array(3.0))
y=add(x,x)
y.backward()
print(x.grad)
- 미분값 재설정
위에서 역전파 시 미분값을 더해주도록 코드를 수정했다. 그런데 이 번경으로 인해 새로운 주의사항이 튀어나온다. 바로 같은 변수를 사용하여 다른 계산을 할 경우 계산이 꼬이는 문제이다.
두번째 x의 미분갑셍 첫 번째 미분 값이 더해지고 5.0이라는 잘못된 값을 돌려준다(2+1+1+1). 3.0이 되어야 한다.
x=Variable(np.array(3.0))
y=add(x,x)
y.backward()
print(x.grad)
x.cleargrad()
y=add(add(x,x),x)
y.backward()
print(x.grad)
'Deep learning > 모델 구현' 카테고리의 다른 글
27. 메모리 관리 (0) | 2021.10.04 |
---|---|
26. 복잡한 계산 그래프 (0) | 2021.10.04 |
24. 가변 길이 인수(순전파) (0) | 2021.10.04 |
23. 역전파 자동화 (0) | 2021.09.29 |
22. 역전파 (0) | 2021.09.29 |