ufris
[TF2] Output Class 개수가 다를 때 pretrained model transfer learning 본문
기존에 학습된 pretrained model의 weight을 사용해서 새로운 task를 학습하는 모델로 transfer learning을 진행할 때
pretrained model과 transfer learning을 진행하는 model간의 class 개수가 달라 weight이 정확히 매칭되지 않는다 문제가
발생합니다
이 문제를 해결하는 방법은 매치되는 weight만 적용시키는 것인데 pretrained model의 weight을 저장할 때
save_format을 tf가 아닌 h5형식으로 저장하는 것 입니다(model.save_weight(save_format='h5')
기존에 model.save_weight을 통해 weight을 저장하면 save_format이 tf로 되어 있는데 이 상태로 class 개수가 다른 모델에
model.load_weight을 진행하면 output shape가 달라 weight을 적용할 수 없다는 에러 납니다
model.load_weight(by_name=True, skip_mismatch=True)를 통해 layer 이름이 같은 것만 weight을 매칭하도록 하고
매칭이 되지 않는 weight은 skip하도록 설정을 하면
-> Weights may only be loaded based on topology into Models when loading TensorFlow formatted weights (got by_name=True to load_weights).
이런 에러가 납니다
이런 에러를 해결하기 위해서는 pretrained model의 weight을 저장할 때 model.save_weight(save_format='h5')로
진행하고 transfer learning을 위해 model.load_weight(by_name=True, skip_mismatch=True)을 진행하면
매칭되는 weight만 적용되서 에러 없이 학습이 잘 되는 것을 볼 수 있습니다
'딥러닝' 카테고리의 다른 글
UMA node read from SysFS had negative value (-1) (0) | 2022.01.07 |
---|---|
Loss와 Accuracy는 항상 반비례 관계인가 (0) | 2021.07.15 |
[Cuda] GTX 3090 사용 시 tensorflow graph 생성이 느린 문제 (0) | 2021.01.20 |
Softmax cross entropy 구현(tensorflow) (0) | 2020.11.19 |
CARPE DIEM, SEIZE THE SAMPLES UNCERTAIN “ATTHE MOMENT” FOR ADAPTIVE BATCH SELECTION (0) | 2020.10.22 |