- weakref 모듈
파이썬에서는 weakref.ref 함수를 사용하여 약한 참조를 만들 수 있다. 약한 참조란 다른 객체를 참조하되 참조 카운트는 증가시키지 않는 기능이다.
이 weakref 구조를 DeZero에도 도입한다. 먼저 Function이다.
import weakref
class Function:
def __call__(self, *inputs):
xs=[x.data for x in inputs]
ys=self.forward(*xs)
if not isinstance(ys,tuple):
ys=(ys,)
outputs=[Variable(as_array(y)) for y in ys]
self.generation=max([x.generation for x in inputs])
for output in outputs:
output.set_creator(self) # 출력 변수에 참조자 설정
self.inputs=inputs
self.outputs=[weakref.ref(output) for output in outputs] # 약한 참조
return outputs if len(outputs)> 1 else outputs[0]
def forward(self,xs):
raise NotImplementedError()
def backward(self,gys):
raise NotImplementedError()
이 변경의 여파로 다른 클래스에서 Function 클래스의 outputs를 참조하는 코드도 수정해야 한다.
이번에는 DeZero의 메모리 사용을 개선할 수 있는 구조 두가지를 도입해보겠다. 첫 번째는 역전파 시 사용하는 메모리양을 줄이는 방법으로, 불필요한 미분 결과를 보관하지 않고 즉시 삭제한다. 두 번째는 '역전파가 필요 없는 경우용 모드'를 제공한느 것이다. 이 모드에서는 불필요한 계산을 생략한다.
- 필요 없는 미분값 삭제
첫 번째로 DeZero의 역전파를 개선하겠다. 현재는 모든 변수가 미분값을 변수에 저장해두고 있다.
그러나 많은 경우, 특히 머신러닝에서는 역전파로 구하고 싶은 미분값은 말단 변수(x0,x1)뿐일 때가 대부분이다. 그러므로 중간 변수에 대해서는 미분값을 제거하는 모드를 추가하겠다.
class Variable:
def __init__(self,data):
if data is not None:
if not isinstance(data,np.ndarray):
raise TypeError('{}는 지원하지 않습니다'.format(type(data)))
self.data=data
self.grad=None
self.creator=None
self.generation=0 # 세대 수를 기록하는 변수
def set_creator(self,func):
self.creator=func
self.generation=func.generation + 1
def backward(self,retain_grad=False):
if self.grad is None:
self.grad = np.ones_like(self.data)
funcs=[]
seen_set=set()
def add_func(f):
if f not in seen_set:
funcs.append(f)
seen_set.add(f)
funcs.sort(key=lambda x: x.generation)
add_func(self.creator)
while funcs:
f=funcs.pop()
gys=[output().grad for output in f.outputs]
gxs=f.backward(*gys)
if not isinstance(gxs,tuple):
gxs=(gxs,)
for x,gx in zip(f.inputs,gxs):
if x.grad is None:
x.grad=gx
else:
x.grad=x.grad + gx
if x.creator is not None:
add_func(x.creator)
if not retain_grad:
for y in f.outputs:
y().grad=None # y는 약한 참조(weakref)
def cleargrad(self):
self.grad=None
우선 backward 인수에 retain_grad를 추가한다. 이 retain_grad가 True면 지금까지처럼 모든 변수가 미분 결과를 유지한다. 반면 retain_grad가 False면 중간 변수의 미분값을 모두 None으로 재설정한다.
- Config 클래스를 활용한 모드 전환
순전파만 할 경우를 위한 개선을 추가하겠다. 우선 두가지 모드, 즉 역전파 활성 모드와 역전파 비활성 모드를 전환하는 구조가 필요하다. 간단히 다음 Config 클래스를 이용할 것이다.
class Config:
enable_backprop=True
Config 클래스를 정의했으니 Function에서 참조하게 하여 모드를 전환할 수 있게 하겠다.
import weakref
class Function:
def __call__(self, *inputs):
xs=[x.data for x in inputs]
ys=self.forward(*xs)
if not isinstance(ys,tuple):
ys=(ys,)
outputs=[Variable(as_array(y)) for y in ys]
if Config.enable_backprop: # 역전파를 할것이면
self.generation=max([x.generation for x in inputs])
for output in outputs:
output.set_creator(self) # 출력 변수에 참조자 설정
self.inputs=inputs
self.outputs=[weakref.ref(output) for output in outputs]
return outputs if len(outputs)> 1 else outputs[0]
def forward(self,xs):
raise NotImplementedError()
def backward(self,gys):
raise NotImplementedError()
- 모드 전환
Config.enable_backprop=True x=Variable(np.ones((100,100,100))) y=square(square(square(x))) y.backward() Config.enable_backprop=False x=Variable(np.ones((100,100,100))) y=square(square(square(x)))
- with 문을 활용한 모드 전환
파이썬에는 with라고 하는, 후처리를 자동으로 수행하고자 할 때 사용할 수 있는 구문이 있다.import contextlib @contextlib.contextmanager def using_config(name,value): old_value=getattr(Config,name) setattr(Config,name,value) try: yield finally: setattr(Config,name,old_value)
with using_config('enable_backprop',False):
x=Variable(np.array(2.0))
y=square(x)
이와 같이 with문 안에서만 역전파 비활성 모드가 된다.
def no_grad():
return using_config('enable_backprop',False)
with no_grad(): # 기울기가 필요 없을때는 이 함수를 호출하면 된다
x=Variable(np.array(2.0))
y=square(x)
'Deep learning > 모델 구현' 카테고리의 다른 글
29. 연산자 오버로드 (0) | 2021.10.10 |
---|---|
28. 변수 사용성 개선 (0) | 2021.10.10 |
26. 복잡한 계산 그래프 (0) | 2021.10.04 |
25. 가변 길이 인수(역전파) (0) | 2021.10.04 |
24. 가변 길이 인수(순전파) (0) | 2021.10.04 |