6개의 클래스에 대해서 1,800개의 이미지를 전이 학습을 통해 학습하는 과정입니다. 다음은 데이터 전처리부터 모델 학습까지의 순서입니다.
- 부족한 데이터의 크기를 키우기 위해서 ImageDataGenerator를 정의해줍니다.
- 모델의 입력으로 사용할 이미지와 한 번에 학습할 학습 데이터의 크기인 batch size를 정의해줍니다.
- 사전에 수집한 이미지 파일을 train과 valid 데이터로 나눠줍니다. 이때 이미지의 크기와 batch size를 사용하여 나눠줍니다.
- 모델의 레이어를 정의해줍니다. (MobileNetV2, GlobalAveragePolling2D, Dropout, Dense)
- 모델 컴파일을 진행한 후에 모델 학습 시 사용할 콜백 함수들을 정의해줍니다.
- 모델 학습을 진행합니다.
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
모델 학습에 필요한 패키지들을 가져와줍니다.
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=30,
shear_range=0.3,
horizontal_flip=True,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.25,
)
valid_datagen = ImageDataGenerator(
rescale=1./255,
)
각 클래스마다 학습과 검증 데이터에 300장의 이미지를 사용하므로 keras에서 제공하는 ImageDataGenerator를 사용하여 데이터의 크기를 키워줍니다.
batch_size = 64
img_width = 128
img_height = 128
train_data = train_datagen.flow_from_directory(
'./train_set/',
batch_size=batch_size,
target_size=(img_width, img_height),
shuffle=True,
)
valid_data = valid_datagen.flow_from_directory(
'./test_set/',
target_size=(img_width, img_height),
batch_size=batch_size,
shuffle=False,
)
이미지의 크기는 180 X 180으로 정의해주고, batch_size는 64로 정의해줍니다. 이후 사전에 shutil 패키지를 사용하여 train, valid 데이터로 분류해놓은 데이터를 학습 데이터와 검증 데이터로 분류합니다.
Found 1500 images belonging to 6 classes.
Found 300 images belonging to 6 classes.
6개의 클래스에 대한 1,500개의 이미지와 300개의 검증 데이터가 분류되었습니다.
def visualize_images(images, labels):
figure, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 14))
classes = list(train_data.class_indices.keys())
img_no = 0
for i in range(3):
for j in range(3):
img = images[img_no]
label_no = np.argmax(labels[img_no])
ax[i,j].imshow(img)
ax[i,j].set_title(classes[label_no])
ax[i,j].set_axis_off()
img_no += 1
images, labels = next(train_data)
visualize_images(images, labels)
matplotlib을 사용하여 학습 데이터의 이미지들이 제대로 분류 되었는지 확인해줍니다.
base = MobileNetV2(input_shape=(img_width, img_height,3),include_top=False,weights='imagenet')
base.trainable = True
model = Sequential()
model.add(base)
model.add(GlobalAveragePooling2D())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(6, activation='softmax'))
opt = Adam(learning_rate=0.001)
model.compile(optimizer=opt,loss = 'categorical_crossentropy',metrics=['accuracy'])
위의 코드와 같이 모델을 구성하고, 이때 optimizer는 Adam을 사용합니다. 학습률인 0.001로 설정하고, loss function은 다중 분류에 자주 사용되는 categorical_crossentropy를 설정해줍니다. 마지막으로 모델 측정 항목은 accuracy로 설정해줍니다.
reduce_lr = ReduceLROnPlateau(monitor = 'val_accuracy',patience = 1,verbose = 1)
early_stop = EarlyStopping(monitor = 'val_accuracy',patience = 5,verbose = 1,restore_best_weights = True)
check_point = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_best_only=True)
다음은 콜백 함수들입니다. ReduceLROnPlateau는 모델 개선이 이뤄지지 않을 경우 학습률을 재설정하여 모델의 개선을 유도하는 콜백 함수입니다. EarlyStopping 과대 적합을 방지하기 위해서 5번의 epoch 동안 모델 개선이 없다면 학습을 종료시켜주는 콜백 함수입니다. ModelCheckpoint는 학습 중에 가장 정확도가 높은 모델의 데이터를 저장해주는 콜백 함수입니다.
history = model.fit(train_data, epochs=50, validation_data = valid_data, callbacks=[early_stop,reduce_lr,check_point])
위의 코드는 모델 학습을 진행하는 코드입니다. 학습 데이터와 epochs(모델 학습 횟수), 검증 데이터와 콜백함수들을 설정해줍니다.
Epoch 00013: val_accuracy did not improve from 0.78333
Epoch 34/50
24/24 [==============================] - 121s 5s/step - loss: 0.4256 - accuracy: 0.8473 - val_loss: 0.7780 - val_accuracy: 0.7833
Restoring model weights from the end of the best epoch.
Epoch 00034: ReduceLROnPlateau reducing learning rate to 1.000000032889008e-20.
Epoch 00034: val_accuracy did not improve from 0.78333
Epoch 00034: early stopping
EarlyStopping에 종료된 마지막 모델 학습의 결과를 나타낸 것입니다. 이때 모델의 학습률은 1.000000032889008e-20이고, 검증 데이터 모델 정확도는 0.7833입니다.
'Artificial Intelligence > Tensorflow' 카테고리의 다른 글
[Tensorflow] shutil을 사용한 데이터 분류 (0) | 2021.05.19 |
---|---|
[Tensorflow] 다중 선형회귀분석이 뭐야? (0) | 2021.05.14 |
[Tensorflow] 경사 하강법이 뭐야? (0) | 2021.05.07 |
[Tensorflow] 선형회귀분석이 뭐야? (0) | 2021.04.28 |
[Tensorflow] 김치 이미지 분류기 - 모델학습 (0) | 2021.04.23 |