Deep learning/이론(hands on machine learning)

6. 케라스로 다층 퍼셉트론 구현(3) - 서브클래싱 API, 모델 저장과 복원, 콜백

jwjwvison 2021. 6. 2. 22:46

 시퀀셜 API와 함수형 API는 모두 선언적이다. 이 방식에는 장점이 많다. 모델을 저장하거나 복사, 공유하기 쉽다. 또한 모델의 구조를 출력하거나 분석하기 좋다. 프레임워크가 크기를 짐작하고 타입을 확인하여 에러를 일찍 발견 할 수 있다(즉 모델에 데이터가 주입되기 전에). 전체 모델이 층으로 구성된 정적 그래프이므로 디버깅하기도 쉽다. 하지만 정적이라는 것이 단점도 된다. 어떤 모델은 반복문을 포함하고 다양한 크기를 다루어야 하며 조건물을 가지는 등 여러 가지 동적인 구조를 필요로 한다. 이런 경우에 조금 더 명령형 프로그래밍 스타일이 필요하다면 서브클래싱 API가 정답이다.

 

< 서브클래싱 API로 동적 모델 만들기>

 간단히 Model 클래스를 생성자 안에서 필요한 층을 만든다. 그다음 call() 메서드 안에 수행하려는 연산을 기술한다. 예를 들어 다음 WideAndDeepModel 클래스의 인스턴스는 앞서 함수형 API로 만든 모델과 동일한 기능을 수행한다.

class WideAndDeepModel(keras.Model):
  def __init__(self,units=30,activation='relu',**kwargs):
    super().__init__(**kwargs)  #표준 매개변수를 처리한다(예를 들면, name)
    self.hidden1=keras.layers.Dense(units,activation=activation)
    self.hidden2=keras.layers.Dense(units,activation=activation)
    self.main_output=keras.layers.Dense(1)
    self.aux_output=keras.layers.Dense(1)

  def call(self,inputs):
    input_A,input_B=inputs
    hidden1=self.hidden1(input_B)
    hidden2=self.hidden2(hidden1)
    concat=keras.layers.concatenate([input_A,hidden2])
    main_output=self.main_output(concat)
    aux_output=self.aux_output(hidden2)
    return main_output,aux_output

model=WideAndDeepModel()

 이 예제는 함수형 API와 매우 비슷하지만 Input 클래스의 객체를 만들 필요가 없다. 대신 call() 메서드의 input 매개변수를 사용한다. 생성자에 있는 층 구성과 call() 메서드에 있는 정방향 계산을 분리했다. 주된 차이점은 call() 메서드 안에서 원하는 어떤 계산도 사용할 수 있다는 것이다. for,if문, 텐서플로 저수준 연산을 사용할 수 있다.

 유연성이 높아지면 그에 따른 비용이 발생한다. 모델 구조가 call() 메서드 안에 숨겨저 있기 때문에 케라스가 쉽게 이를 분석할 수 없다. 즉 모델을 저장하거나 복사할 수 없다.

 

< 모델 저장과 복원 >

 시퀀셜 API와 함수형 API를 사용하면 훈련된 케라스 모델을 저장하는 것은 다음과 같이 매우 쉽다.

 케라스는 HDF5 포맷을 사용하여 (모든 층의 하이퍼파라미터를 포함하여) 모델 구조와 층의 모든 모델 파라미터(즉 연결 가중치와 편향)를 저장한다. 또한 (하이퍼파라미터와 현재 상태를 포함하여) 옵티마이저도 저장한다. 

 

 일반적으로 하나의 파이썬 스크립트에서 모델을 훈련하고 저장한 다음 하나 이상의 스크립트(또는 웹 서비스)에서 모델을 로드하고 예측을 만드는 데 활용한다.

 

< 콜백 사용하기 >

 콜백(callback)를 사용하면 fit() 메서드에서 체크포인트를 저장할 수 있다.

 fit() 메서드의 callbacks 매개변수를 사용하여 케라스가 훈련의 시작이나 끝에 호출할 객체 리스트를 지정할 수 있다. 또는 에포크의 시작이나 끝, 각 배치 처리 전후에 호출할 수도 있다. 예를 들어 ModelCheckpoint는 훈련하는 동안 일정한 간격으로 모델의 체크포인트를 지정한다. 기본적으로 매 에포크의 끝에서 호출된다.

 모델이 향상되지 않으면 훈련이 자동으로 중지되기 때문에 에포크의 숫자를 크게 지정해도 된다. 이 경우 EarlyStopping 콜백이 훈련이 끝난 후 최상의 가중치를 복원하기 때문에 저장된 모델을 따로 복원할 필요가 없다.

 

 더 많은 제어를 원한다면 사용자 정의 콜백을 만들수 있다.