Transfer Learning
ํ์ต ๋ฐ์ดํฐ๊ฐ ๋ถ์กฑํ ๋ถ์ผ์ ๋ชจ๋ธ ๊ตฌ์ถ์ ์ํด ๋ฐ์ดํฐ๊ฐ ํ๋ถํ ๋ถ์ผ์์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ฌ์ฌ์ฉํ๋ ํ์ต ๊ธฐ๋ฒ
Imagenet(๋๊ท๋ชจ ๋ฐ์ดํฐ์ ) ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํด์ ์ฌ์ ํ์ต๋(pre-trained) ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง๊ณ ์์ ์ฐ๋ฆฌ๊ฐ ํด๊ฒฐํ๊ณ ์ ํ๋ ๊ณผ์ ์ ๋ง๊ฒ ์ฌ๋ณด์ ํด์ ์ฌ์ฉ
๋น๊ต์ ์ ์ ์์ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ง๊ณ ๋ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๊ณผ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ๋ จ
→ pre-trained model์ weights๋ฅผ ์ฝ๊ฐ์ฉ ๋ณํ์์ผ ์ ์ ๋ฐ์ดํฐ์ ์์ task์ ๋ง๊ฒ ์ฌ์ฌ์ฉ
→ pre-trained model์ classifier๋ ์ญ์ ํ๊ณ ๋ชฉ์ ์ ๋ง๋ ์๋ก์ด classifier ์ถ๊ฐ
⇒ ์๋กญ๊ฒ ๋ง๋ค์ด์ง ๋ชจ๋ธ fine tuning ์งํ (strategy 3๊ฐ ์ค 1๊ฐ ์ ํํด์ ์งํ)
fine tuning
pre-trained model์ ๊ธฐ๋ฐ์ผ๋ก ์ํคํ ์ณ๋ฅผ ์๋ก์ด ๋ชฉ์ ์ ๋ง๊ฒ ๋ณํํ๊ณ ์ด๋ฏธ ํ์ต๋ ๋ชจ๋ธ Weights๋ก๋ถํฐ ํ์ต์ ์ ๋ฐ์ดํธํ๋ ๋ฐฉ๋ฒ
→ pre-trained model์ weights๋ฅผ ๋ฏธ์ธํ๊ฒ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ
๋ฐฉ๋ฒ 3๊ฐ์ง
input ์ชฝ ์ฌ๊ฐํ : conv base
prediction ์ชฝ ์ฌ๊ฐํ : classifier
- Strategy 1 : ์ ์ฒด ๋ชจ๋ธ์ ์๋ก ํ์ต
์ฌ์ ํ์ต ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ง ์ฌ์ฉํ๋ฉด์ ์์ ์ ๋ฐ์ดํฐ์ ์ ๋ง๊ฒ ์ ๋ถ ์๋ก ํ์ต์ํค๋ ๋ฐฉ๋ฒ
→ ํฌ๊ธฐ๊ฐ ํฌ๊ณ ์ ์ฌ์ฑ์ด ๋ฎ์ ๋ฐ์ดํฐ์ ์ผ ๋ ์ถ์ฒ
- Strategy 2 : Convolutional base ์ผ๋ถ๋ถ์ ๊ณ ์ , ๋๋จธ์ง ๊ณ์ธต๊ณผ Classifier๋ฅผ ์๋ก ํ์ต
- ๋ฎ์ ๋ ๋ฒจ์ ๊ณ์ธต → ์ผ๋ฐ์ ์ธ ํน์ง ์ถ์ถ
- ๋์ ๋ ๋ฒจ์ ๊ณ์ธต → ๊ตฌ์ฒด์ ์ด๊ณ ํน์ ํ ํน์ง(Task์ ๋ฐ๋ผ ๋ฌ๋ผ์ง๋ ํน์ง) ์ถ์ถ
→ ํฌ๊ธฐ๊ฐ ํฌ๊ณ ์ ์ฌ์ฑ์ด ๋์ ๋ฐ์ดํฐ์ ์ผ ๋ ์ถ์ฒ
→ ํฌํค๊ฐ ์๊ณ ์ ์ฌ์ฑ๋ ๋ฎ์ ๋์๋ ์ ์ํฉ๋ณด๋ค ์กฐ๊ธ ๋ ๊น๊ฒ ์ฌํ์ต
- Strategy 3 : Convolutional base๋ ๊ณ ์ ์ํค๊ณ , Classifier๋ง ์๋ก ํ์ต
- ์ปดํจํ ์ฐ์ฐ ๋ฅ๋ ฅ์ด ๋ถ์กฑํ ๋
- ๋ฐ์ดํฐ ์ ์ด ๋๋ฌด ์์ ๋
- ์ ์ฉํ๋ ค๋ Task๊ฐ ํ์ต๋ชจ๋ธ์ด ์ด๋ฏธ ํ์ตํ ๋ฐ์ดํฐ ์ ๊ณผ ๋งค์ฐ ๋น์ทํ ๋
tensorflow
- tensorflow๊ฐ ์ ๊ณตํ๋ pre-trained model
์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ ๋ฐ์ดํฐ์ ์์ ์ด๋ค ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ๊ฐ์ฅ ์ข์์ง ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ ๋ค์ํ ๋ชจ๋ธ์ ํ์ต์์ผ๋ณด๊ณ ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํด๋ด์ผ ํจ
- ์ฌ์ฉ๋ฒ
tensorflow์์ transfer learning์ ์ฌ์ฉํ๊ธฐ ์ํด์ ๋ค์ํ pre-trained ๋ชจ๋ธ์ import
๋ชจ๋ธ ํ๋ผ๋ฏธํฐ
- weights : ์ฌ์ ํ์ต์ ์ฌ์ฉ๋ ๋ฐ์ดํฐ์
- include_top=False : ๋ชจ๋ธ์ ํน์ง ์ถ์ถ๊ธฐ๋ง ๊ฐ์ ธ์ด (๋ถ๋ฅ๊ธฐ๋ ๊ฐ์ ธ์ค์ง ์์)
- input_shape : ์๋กญ๊ฒ ํ์ต์ํฌ ์ด๋ฏธ์ง ํ ์ ํฌ๊ธฐ
์ ์ดํ์ต + ์ปค์คํฐ๋ง์ด์ง → ๋ชจ๋ธ ์์ฑ
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.applications import # ์ฌ์ฉํ pre-trained ๋ชจ๋ธ
# ์๋ VGG16์ ๊ตฌ์กฐ ํ์ธ
model = VGG16()
model.summary()
# ์์ ํ ๊ตฌ์กฐ ํ์ธ
model = VGG16(input_shape=(32, 32, 3), include_top=False, weights='imagenet')
model.summary()
output = model.output # ๋ง์ง๋ง ์ถ๋ ฅ๊ฒฐ๊ณผ ๋ฐํ
x = GlobalAveragePooling2D()(output)
x = Dense(50, activation='relu')(x)
output = Dense(10, activation='softmax', name='output')(x) # softmax 10๊ฐ๋ก ๋ถ๋ฅ
model = Model(inputs=model.input, outputs=output)
model.summary()
๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์๋์ ๊ฐ์ด ์์
- input_shape๋ (32, 32, 3)
- classification layer๋ฅผ ์ ๊ฑฐํ๊ณ ์ปค์คํฐ๋ง์ด์ง
- ๊ฐ์ค์น๋ imagenet์ผ๋ก ์ ์ดํ์ต
์๋ ๋ชจ๋ธ ๊ตฌ์กฐ์์ Flatten, Dense ์ ๊ฑฐ → ์์ ๋ ๋ชจ๋ธ
์ ์ฝ๋์ฒ๋ผ ์์ ํ ์๋ ์๊ณ ์๋ ์ฝ๋์ฒ๋ผ ํ layer์ฉ ์ถ๊ฐ ๊ฐ๋ฅ
- Flatten() : pre-trained ๋ชจ๋ธ๊ณผ ์๋กญ๊ฒ ์ถ๊ฐํ๋ ๋ถ๋ฅ๊ธฐ ์ฐ๊ฒฐ
- compile() : loss, optimizer, metric ์ ์
ํ์ต
- ImageDataGenerator & train, test data ์ฝ์ด์ค๊ธฐ
- ๋ชจ๋ธ ์ปดํ์ผ & ํ์ต
- ์์ค & ์ ํ๋
์์ฐธ์ฐธ ์ฐธ๊ณ ๋งํฌ
's t u d y . . ๐ง > AI ์ค ML ์ค DL' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[NLP | BERT & SBERT] Cross-Encoder์ Bi-Encoder (0) | 2023.04.27 |
---|---|
[PyTorch] iris ๋ฐ์ดํฐ ๋ถ๋ฅ ~ (w/๋ฉํฐ ํผ์ ํธ๋ก ) (0) | 2023.02.26 |
[YOLOv5] ํ์ต๋ ๋ชจ๋ธ๋ก ์ด๋ฏธ์ง test ํ ํ ํ์ผ์ ์ ์ฅํ๊ธฐ (0) | 2022.10.08 |
[YOLOv5] YOLOv5 ์ฌ์ฉ๋ฒ (1) | 2022.10.07 |
[YOLOv5] Custom Dataset์ผ๋ก Pothole detection (0) | 2022.10.07 |