PyTorch

[PyTorch] Weight Initialization

슈넌 2022. 6. 30. 19:57

Weight Initialization은 Local Minimum을 찾기 위한 시작점을 정해주는 방법이다. PyTorch를 통해 모델을 처음 만들게 되면 초기 Weight는 Random Initialization으로 되어 있다. 이는 Local Minimum을 찾아내기 위한 적절한 Initialization 방법이 아니다. 각 상황에 따라, 모델에 따라 적절한 Initialization 방법은 다르지만, 여기서 소개할 방법은 크게 두가지이다.

1. Xavier Initialization

Xavier Initialization은 Xavier Glorot와 Yoshua Bengio가 만든 방법으로 인풋 채널과 아웃풋 채널에 따라 Gaussian Distribution 또는 Uniform Distribution을 이용하는 방법이다.

 

1.1 Xavier Uniform Initialization

u는 Uniform Distribution을 뜻한다. 이때 범위는 a에 의해 결정된다. a 식에서 gain은 scale factor이고, fan_in과 fan_out에 따라서 범위가 달라지게 된다. fan_in은 receptive field와 input channel의 곱이고, fan_out은 output channel의 곱이다.

 

코드: torch.nn.init.xavier_uniform_(tensor, gain=1.0)

 

1.2 Xavier Normal Initialization

N은 Gaussian Distribution을 의미한다. 이때 std는 표준편차, mean은 0의 값을 갖는다. gain은 scale factor로 사용되며 fan_in과 fan_out은 앞에서 설명한 것과 동일한다.

 

코드: torch.nn.init.xavier_normal_(tensor, gain=1.0)

 

2. He Initialization

Kaiming He는 유명하기 때문에 많이들 들어봤을 것이다. He Initialization 또한 그의 이름을 따서 만든 것이다. 이는 ReLU Activation Function을 사용할 때 발생하는 0으로 수렴하는 문제를 해결하기 위해서 만든 것이다. 그렇기 때문에 ReLU를 사용하는 상황이라면 He Initialization을 추천한다.

 

2.1 He Uniform Initialization

u는 Uniform Distribution이며, bound에 의해 영역이 정해진다. gain은 앞서 말한대로 scale factor이며, fan_mode는 fan_in 중에 fan_out을 고르는 것이다. default 값으로는 fan_in을 사용하여 input channel을 기준으로 하도록 한다.

 

코드: torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

 

2.1 He Normal Initialization

N은 Gaussian Distribution을 의미하고, std를 계산할 때는 scale factor인 gain과 앞서 설명한 fan_mode를 사용한다.

 

코드: torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

 

3. 모델 적용 코드

모델에 적용을 하고 싶을때, 모든 layer마다 적용할 수 없기에 다음과 같은 함수를 이용한다. bias의 경우에는 0으로 Initialization을 해주거나 0.01과 같은 매우 작은 값으로 Initialization을 해준다.

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
model.apply(init_weights)

위의 예시는 He Uniform Initialization을 적용한 예시이다.  Batch Normalization에는 weight를 1로, bias를 0으로 초기화 해주었다.

 

Weight Initialization은 성능을 높이기 위한 가장 기본적인 방법이므로 제대로 익힐 필요가 있다.

 

참고: PyTorch Document