Notice
Recent Posts
Recent Comments
Link
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Tags
more
Archives
Today
Total
관리 메뉴

ufris

[TF2] Output Class 개수가 다를 때 pretrained model transfer learning 본문

딥러닝

[TF2] Output Class 개수가 다를 때 pretrained model transfer learning

ufris 2021. 7. 30. 17:38

기존에 학습된 pretrained model의 weight을 사용해서 새로운 task를 학습하는 모델로 transfer learning을 진행할 때 

 

pretrained model과 transfer learning을 진행하는 model간의 class 개수가 달라 weight이 정확히 매칭되지 않는다 문제가

 

발생합니다

 

이 문제를 해결하는 방법은 매치되는 weight만 적용시키는 것인데 pretrained model의 weight을 저장할 때

save_format을 tf가 아닌 h5형식으로 저장하는 것 입니다(model.save_weight(save_format='h5')

 

기존에 model.save_weight을 통해 weight을 저장하면 save_format이 tf로 되어 있는데 이 상태로 class 개수가 다른 모델에

 

model.load_weight을 진행하면 output shape가 달라 weight을 적용할 수 없다는 에러 납니다

 

model.load_weight(by_name=True, skip_mismatch=True)를 통해 layer 이름이 같은 것만 weight을 매칭하도록 하고

 

매칭이 되지 않는 weight은 skip하도록 설정을 하면

 

-> Weights may only be loaded based on topology into Models when loading TensorFlow formatted weights (got by_name=True to load_weights).

 

이런 에러가 납니다

 

이런 에러를 해결하기 위해서는 pretrained model의 weight을 저장할 때 model.save_weight(save_format='h5')로

 

진행하고 transfer learning을 위해 model.load_weight(by_name=True, skip_mismatch=True)을 진행하면

 

매칭되는 weight만 적용되서 에러 없이 학습이 잘 되는 것을 볼 수 있습니다