지금까지는 합수에 입출력 변수가 하나씩인 경우만 생각했다. 그러나 함수에 따라 여러 개의 변수를 입력 받기도 한다.
- Function 클래스 수정
가변 길이 입출력을 표현하기 위해 변수들을 리스트(또는 튜플)에 넣어 처리한다.
class Function:
def __call__(self,inputs):
xs=[x.data for x in inputs]
ys=self.forward(xs)
outputs=[Variable(as_array(y)) for y in ys]
for output in outputs:
output.set_creator(self) # 출력 변수에 참조자 설정
self.inputs=inputs
self.outputs=outputs # 출력도 저장
return outputs
def forward(self,xs):
raise NotImplementedError()
def backward(self,gys):
raise NotImplementedError()
- Add 클래스 구현
Add 클래스의 forward 메서드를 구현한다. 주의할 점은 인수와 반환값이 리스트(또는 튜플)여야 한다는 것이다.
class Add(Function):
def forward(self,xs):
x0,x1=xs
y=x0+x1
return (y,)
xs=[Variable(np.array(2)),Variable(np.array(3))]
f=Add()
ys=f(xs)
y=ys[0]
print(y.data)
- 함수를 사용하기 쉽게, 구현하기 쉽게 개선: 함수를 사용하기 쉽게
class Function:
def __call__(self, *inputs): #2
xs=[x.data for x in inputs]
ys=self.forward(*xs) #3
if not isinstance(ys,tuple): #4
ys=(ys,)
outputs=[Variable(as_array(y)) for y in ys]
for output in outputs:
output.set_creator(self) # 출력 변수에 참조자 설정
self.inputs=inputs
self.outputs=outputs # 출력도 저장
return outputs if len(outputs)> 1 else outputs[0] #1
def forward(self,xs):
raise NotImplementedError()
def backward(self,gys):
raise NotImplementedError()
(1) __call__의 return 부분을 보면, outputs에 원소가 하나뿐이면 리스트가 아니라 그 원소만을 반환한다. 다시 말해 함수의 반환값이 하나라면 해당 변수를 직접 돌려준다.
(2)함수를 정의할 때 인수 앞에 별표(*)를 붙였다. 이렇게 하면 리스트를 사용하는 대신 임의 개수의 인수(가변 길이 인수)를 건네 함수를 호출할 수 있다. 가변 길이 인수의 사용법은 다음 예를 보면 명확해질것이다.
이 코드에서 알 수 있듯이 함수를 '정의'할 때 인수에 별표를 붙이면 호출할 때 넘긴 인수들을 별표를 붙인 인수 하나로 모아서 받을 수 있다.
(3) 함수를 '호출'할 때 별표를 붙였는데, 이렇게 하면 리스트 언팩이 이루어진다.
(4) ys가 튜플이 아닌 경우 튜플로 변경한다. 이제 forward 메서드는 반환 원소가 하나 뿐이라면 해당 원소를 직접 반환한다.
class Add(Function):
def forward(self,x0,x1):
y=x0+x1
return y
마지막으로 Add 클래스를 파이썬 함수로 사용할 수 있는 코드를 추가하겠다.
def add(x0,x1):
return Add()(x0,x1)
x0=Variable(np.array(2))
x1=Variable(np.array(3))
y=add(x0,x1)
print(y.data)
'Deep learning > 모델 구현' 카테고리의 다른 글
26. 복잡한 계산 그래프 (0) | 2021.10.04 |
---|---|
25. 가변 길이 인수(역전파) (0) | 2021.10.04 |
23. 역전파 자동화 (0) | 2021.09.29 |
22. 역전파 (0) | 2021.09.29 |
21. 수치미분 (0) | 2021.09.28 |