Deep learning/모델 구현

26. 복잡한 계산 그래프

jwjwvison 2021. 10. 4. 16:38

 지금의 DeZero는 다음 그림과 같은 계산 그래프도 만들 수 있다.

 그러나 지금의 DeZero는 이런 계산의 미분은 제대로 계산하지 못한다. 더 정확하게는 이런 복잡한 연결의 역전파를 제대로 할 수 없다.

함수 우선 순위

 

  • 세대 추가
    class Variable:
        def __init__(self,data):
            if data is not None:
                if not isinstance(data,np.ndarray):
                    raise TypeError('{}는 지원하지 않습니다'.format(type(data)))
            self.data=data
            self.grad=None
            self.creator=None
            self.generation=0 # 세대 수를 기록하는 변수
            
        def set_creator(self,func):
            self.creator=func
            self.generation=func.generation + 1
            
        def backward(self):
            if self.grad is None:
                self.grad = np.ones_like(self.data)
            
            funcs=[self.creator]
            
            while funcs:
                f=funcs.pop()
                gys=[output.grad for output in f.outputs]
                gxs=f.backward(*gys)
                if not isinstance(gxs,tuple):
                    gxs=(gxs,)
                    
                for x,gx in zip(f.inputs,gxs):
                    if x.grad is None:
                        x.grad=gx
                    else:
                        x.grad=x.grad + gx
                 
                    if x.creator is not None:
                        funcs.append(x.creator)
                    
        
        def cleargrad(self):
            self.grad=None​

 Variable 클래스는 generation을 0으로 초기화한다. 그리고 set_creator 메서드가 호출될 때 부모 함수의 새대보다 1만큼 큰 값을 설정한다.

 

 다음 차례는 Functoin클래스이다. Function클래스의 generation은 입력 변수와 같은 값으로 설정한다.

 입력 변수가 둘 이상이라면 가장 큰 generation의 수를 선택한다.

class Function:
    def __call__(self, *inputs):
        xs=[x.data for x in inputs]
        
        ys=self.forward(*xs)
        if not isinstance(ys,tuple):
            ys=(ys,)
        outputs=[Variable(as_array(y)) for y in ys]
        
        self.generation=max([x.generation for x in inputs])
        for output in outputs:
            output.set_creator(self) # 출력 변수에 참조자 설정
        self.inputs=inputs
        self.outputs=outputs # 출력도 저장  [6]
        return outputs if len(outputs)> 1 else outputs[0]
    
    def forward(self,xs):
        raise NotImplementedError()
        
    def backward(self,gys):
        raise NotImplementedError()

 

  • 세대 순으로 꺼내기

 이렇게 세대가 설정되어 있으면 역전파 때 함수를 올바른 순서로 꺼낼 수 있다. 예를 들어 함수 A보다 세대가 큰 B와 C를 먼저 꺼내게 된다.

 

  • Variable 클래스의 backward
    class Variable:
        def __init__(self,data):
            if data is not None:
                if not isinstance(data,np.ndarray):
                    raise TypeError('{}는 지원하지 않습니다'.format(type(data)))
            self.data=data
            self.grad=None
            self.creator=None
            self.generation=0 # 세대 수를 기록하는 변수
            
        def set_creator(self,func):
            self.creator=func
            self.generation=func.generation + 1
            
        def backward(self):
            if self.grad is None:
                self.grad = np.ones_like(self.data)
            
            funcs=[]
            seen_set=set()
            
            def add_func(f):
                if f not in seen_set:
                    funcs.append(f)
                    seen_set.add(f)
                    funcs.sort(key=lambda x: x.generation)
                    
            add_func(self.creator)
            
            while funcs:
                f=funcs.pop()
                gys=[output.grad for output in f.outputs]
                gxs=f.backward(*gys)
                if not isinstance(gxs,tuple):
                    gxs=(gxs,)
                    
                for x,gx in zip(f.inputs,gxs):
                    if x.grad is None:
                        x.grad=gx
                    else:
                        x.grad=x.grad + gx
                 
                    if x.creator is not None:
                        add_func(x.creator)
                    
        
        def cleargrad(self):
            self.grad=None​

 

'Deep learning > 모델 구현' 카테고리의 다른 글

28. 변수 사용성 개선  (0) 2021.10.10
27. 메모리 관리  (0) 2021.10.04
25. 가변 길이 인수(역전파)  (0) 2021.10.04
24. 가변 길이 인수(순전파)  (0) 2021.10.04
23. 역전파 자동화  (0) 2021.09.29