s t u d y . . ๐Ÿง/AI ์•ค ML ์•ค DL

[Transfer Learning] ์ „์ดํ•™์Šต ๊ฐœ๋…

H J 2022. 10. 8. 04:07

Transfer Learning

ํ•™์Šต ๋ฐ์ดํ„ฐ๊ฐ€ ๋ถ€์กฑํ•œ ๋ถ„์•ผ์˜ ๋ชจ๋ธ ๊ตฌ์ถ•์„ ์œ„ํ•ด ๋ฐ์ดํ„ฐ๊ฐ€ ํ’๋ถ€ํ•œ ๋ถ„์•ผ์—์„œ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ์žฌ์‚ฌ์šฉํ•˜๋Š” ํ•™์Šต ๊ธฐ๋ฒ•

Imagenet(๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์…‹) ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•ด์„œ ์‚ฌ์ „ํ•™์Šต๋œ(pre-trained) ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ง€๊ณ  ์™€์„œ ์šฐ๋ฆฌ๊ฐ€ ํ•ด๊ฒฐํ•˜๊ณ ์ž ํ•˜๋Š” ๊ณผ์ œ์— ๋งž๊ฒŒ ์žฌ๋ณด์ •ํ•ด์„œ ์‚ฌ์šฉ

๋น„๊ต์  ์ ์€ ์ˆ˜์˜ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ง€๊ณ ๋„ ์šฐ๋ฆฌ๊ฐ€ ์›ํ•˜๋Š” ๊ณผ์ œ๋ฅผ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š” ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ํ›ˆ๋ จ

 

→ pre-trained model์˜ weights๋ฅผ ์•ฝ๊ฐ„์”ฉ ๋ณ€ํ™”์‹œ์ผœ ์ ์€ ๋ฐ์ดํ„ฐ์…‹์—์„œ task์— ๋งž๊ฒŒ ์žฌ์‚ฌ์šฉ

→ pre-trained model์˜ classifier๋Š” ์‚ญ์ œํ•˜๊ณ  ๋ชฉ์ ์— ๋งž๋Š” ์ƒˆ๋กœ์šด classifier ์ถ”๊ฐ€

⇒ ์ƒˆ๋กญ๊ฒŒ ๋งŒ๋“ค์–ด์ง„ ๋ชจ๋ธ fine tuning ์ง„ํ–‰ (strategy 3๊ฐœ ์ค‘ 1๊ฐœ ์„ ํƒํ•ด์„œ ์ง„ํ–‰)

 

fine tuning

pre-trained model์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์•„ํ‚คํ…์ณ๋ฅผ ์ƒˆ๋กœ์šด ๋ชฉ์ ์— ๋งž๊ฒŒ ๋ณ€ํ˜•ํ•˜๊ณ  ์ด๋ฏธ ํ•™์Šต๋œ ๋ชจ๋ธ Weights๋กœ๋ถ€ํ„ฐ ํ•™์Šต์„ ์—…๋ฐ์ดํŠธํ•˜๋Š” ๋ฐฉ๋ฒ•

→ pre-trained model์˜ weights๋ฅผ ๋ฏธ์„ธํ•˜๊ฒŒ ์กฐ์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•

 

๋ฐฉ๋ฒ• 3๊ฐ€์ง€

input ์ชฝ ์‚ฌ๊ฐํ˜• : conv base

prediction ์ชฝ ์‚ฌ๊ฐํ˜• : classifier

 

 

  • Strategy 1 : ์ „์ฒด ๋ชจ๋ธ์„ ์ƒˆ๋กœ ํ•™์Šต
    ์‚ฌ์ „ํ•™์Šต ๋ชจ๋ธ์˜ ๊ตฌ์กฐ๋งŒ ์‚ฌ์šฉํ•˜๋ฉด์„œ ์ž์‹ ์˜ ๋ฐ์ดํ„ฐ์…‹์— ๋งž๊ฒŒ ์ „๋ถ€ ์ƒˆ๋กœ ํ•™์Šต์‹œํ‚ค๋Š” ๋ฐฉ๋ฒ•
    → ํฌ๊ธฐ๊ฐ€ ํฌ๊ณ  ์œ ์‚ฌ์„ฑ์ด ๋‚ฎ์€ ๋ฐ์ดํ„ฐ์…‹์ผ ๋•Œ ์ถ”์ฒœ

 

  • Strategy 2 : Convolutional base ์ผ๋ถ€๋ถ„์€ ๊ณ ์ •, ๋‚˜๋จธ์ง€ ๊ณ„์ธต๊ณผ Classifier๋ฅผ ์ƒˆ๋กœ ํ•™์Šต
    • ๋‚ฎ์€ ๋ ˆ๋ฒจ์˜ ๊ณ„์ธต → ์ผ๋ฐ˜์ ์ธ ํŠน์ง• ์ถ”์ถœ
    • ๋†’์€ ๋ ˆ๋ฒจ์˜ ๊ณ„์ธต → ๊ตฌ์ฒด์ ์ด๊ณ  ํŠน์œ ํ•œ ํŠน์ง•(Task์— ๋”ฐ๋ผ ๋‹ฌ๋ผ์ง€๋Š” ํŠน์ง•) ์ถ”์ถœ
    ์–ด๋Š ์ •๋„๊นŒ์ง€ ์žฌํ•™์Šต ์‹œํ‚ฌ์ง€ ์ •ํ•จ
    → ํฌ๊ธฐ๊ฐ€ ํฌ๊ณ  ์œ ์‚ฌ์„ฑ์ด ๋†’์€ ๋ฐ์ดํ„ฐ์…‹์ผ ๋•Œ ์ถ”์ฒœ
    → ํฌํ‚ค๊ฐ€ ์ž‘๊ณ  ์œ ์‚ฌ์„ฑ๋„ ๋‚ฎ์„ ๋•Œ์—๋Š” ์œ„ ์ƒํ™ฉ๋ณด๋‹ค ์กฐ๊ธˆ ๋” ๊นŠ๊ฒŒ ์žฌํ•™์Šต

 

  • Strategy 3 : Convolutional base๋Š” ๊ณ ์ •์‹œํ‚ค๊ณ , Classifier๋งŒ ์ƒˆ๋กœ ํ•™์Šต
    • ์ปดํ“จํŒ… ์—ฐ์‚ฐ ๋Šฅ๋ ฅ์ด ๋ถ€์กฑํ•  ๋•Œ
    • ๋ฐ์ดํ„ฐ ์…‹์ด ๋„ˆ๋ฌด ์ž‘์„ ๋•Œ
    • ์ ์šฉํ•˜๋ ค๋Š” Task๊ฐ€ ํ•™์Šต๋ชจ๋ธ์ด ์ด๋ฏธ ํ•™์Šตํ•œ ๋ฐ์ดํ„ฐ ์…‹๊ณผ ๋งค์šฐ ๋น„์Šทํ•  ๋•Œ

 

strategy 3, 2, 1 ์ˆœ

 

tensorflow

  • tensorflow๊ฐ€ ์ œ๊ณตํ•˜๋Š” pre-trained model

์šฐ๋ฆฌ๊ฐ€ ์‚ฌ์šฉํ•  ๋ฐ์ดํ„ฐ์…‹์—์„œ ์–ด๋–ค ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด ๊ฐ€์žฅ ์ข‹์„์ง€ ๋ชจ๋ฅด๊ธฐ ๋•Œ๋ฌธ์— ๋‹ค์–‘ํ•œ ๋ชจ๋ธ์„ ํ•™์Šต์‹œ์ผœ๋ณด๊ณ  ๊ฒฐ๊ณผ๋ฅผ ๋น„๊ตํ•ด๋ด์•ผ ํ•จ

  • ์‚ฌ์šฉ๋ฒ•

tensorflow์—์„œ transfer learning์„ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ ๋‹ค์–‘ํ•œ pre-trained ๋ชจ๋ธ์„ import

 

๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ

  • weights : ์‚ฌ์ „ํ•™์Šต์— ์‚ฌ์šฉ๋œ ๋ฐ์ดํ„ฐ์…‹
  • include_top=False : ๋ชจ๋ธ์˜ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋งŒ ๊ฐ€์ ธ์˜ด (๋ถ„๋ฅ˜๊ธฐ๋Š” ๊ฐ€์ ธ์˜ค์ง€ ์•Š์Œ)
  • input_shape : ์ƒˆ๋กญ๊ฒŒ ํ•™์Šต์‹œํ‚ฌ ์ด๋ฏธ์ง€ ํ…์„œ ํฌ๊ธฐ

 

์ „์ดํ•™์Šต + ์ปค์Šคํ„ฐ๋งˆ์ด์ง• → ๋ชจ๋ธ ์ƒ์„ฑ

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.applications import # ์‚ฌ์šฉํ•  pre-trained ๋ชจ๋ธ

# ์›๋ž˜ VGG16์˜ ๊ตฌ์กฐ ํ™•์ธ
model = VGG16()
model.summary()

# ์ˆ˜์ • ํ›„ ๊ตฌ์กฐ ํ™•์ธ
model = VGG16(input_shape=(32, 32, 3), include_top=False, weights='imagenet')
model.summary()

output = model.output # ๋งˆ์ง€๋ง‰ ์ถœ๋ ฅ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜

x = GlobalAveragePooling2D()(output)
x = Dense(50, activation='relu')(x)
output = Dense(10, activation='softmax', name='output')(x) # softmax 10๊ฐœ๋กœ ๋ถ„๋ฅ˜

model = Model(inputs=model.input, outputs=output)
model.summary()

๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์•„๋ž˜์™€ ๊ฐ™์ด ์ˆ˜์ •

  • input_shape๋Š” (32, 32, 3)
  • classification layer๋ฅผ ์ œ๊ฑฐํ•˜๊ณ  ์ปค์Šคํ„ฐ๋งˆ์ด์ง•
  • ๊ฐ€์ค‘์น˜๋Š” imagenet์œผ๋กœ ์ „์ดํ•™์Šต

์›๋ž˜ ๋ชจ๋ธ์˜ ๊ตฌ์กฐ
์ˆ˜์ •๋œ ๋ชจ๋ธ์˜ ๊ตฌ์กฐ

์›๋ž˜ ๋ชจ๋ธ ๊ตฌ์กฐ์—์„œ Flatten, Dense ์ œ๊ฑฐ → ์ˆ˜์ •๋œ ๋ชจ๋ธ

 

 

์œ„ ์ฝ”๋“œ์ฒ˜๋Ÿผ ์ˆ˜์ •ํ•  ์ˆ˜๋„ ์žˆ๊ณ  ์•„๋ž˜ ์ฝ”๋“œ์ฒ˜๋Ÿผ ํ•œ layer์”ฉ ์ถ”๊ฐ€ ๊ฐ€๋Šฅ

  • Flatten() : pre-trained ๋ชจ๋ธ๊ณผ ์ƒˆ๋กญ๊ฒŒ ์ถ”๊ฐ€ํ•˜๋Š” ๋ถ„๋ฅ˜๊ธฐ ์—ฐ๊ฒฐ
  • compile() : loss, optimizer, metric ์ •์˜

 

ํ•™์Šต

  • ImageDataGenerator & train, test data ์ฝ์–ด์˜ค๊ธฐ

 

  • ๋ชจ๋ธ ์ปดํŒŒ์ผ & ํ•™์Šต

 

  • ์†์‹ค & ์ •ํ™•๋„


 

์ž ์™€..

 

์•„์ฐธ์ฐธ ์ฐธ๊ณ  ๋งํฌ

๋”๋ณด๊ธฐ