Mạng MobileNet là gì
Xin chào tuần mới các mem, hôm nay chúng ta thử đóng vai một lazy boy thử xây dựng hệ thống tự động chỉnh màu cho ảnh nhá. Show Chả là anh chàng lazy boy có một cô người yêu khó tính gửi cho 1 tập ảnh và yêu cầu chỉnh màu cho cô. Vốn sẵn tính lười nên lazy boy muốn code ra một hệ thống có thể làm thay anh ta điều đó không thì chết với cô người yêu =)). Quan sát lại các ảnh mà cô người yêu đã sửa và up facebook trước đây, chàng lazy boy nhận ra rằng hóa ra có quy luật cả. Mỗi loại ảnh thì cô nàng đều apply với một loại chỉnh màu riêng. Kaka. Thế là chàng lười mới quyết định tạo ra một phần mềm tự động chỉnh màu tuỳ vào ảnh cho nhanh gọn gồm 2 loại:
Rồi , tạm thế đã, nếu ổn thì sẽ làm với tất cả các loại ảnh của nàng gửi =)) Phần 1 Phân tích bài toán tự động chỉnh màu ảnhVới bài toán này ta có thể nghĩ tới các kỹ thuật như sau:
Chúng ta sẽ sử dụng bộ dữ liệu gồm các ảnh indoor và outdoor để train. Các bạn có thể tải bộ dữ liệu tại Thư viện Mì AI: https://miai.vn/thu-vien-mi-ai . Bạn xem video clip để biết cách tải về. Phần 2 Xây dựng các filter cho ảnhNhư đã nói ở trên, sau khi nhận diện được ảnh đưa vào là trong nhà hay ngoài trời thì model sẽ áp dụng các filter ảnh tương ứng (thêm màu ấm, sapie). Như vậy chúng ta phải viết sẵn các filter này cho model, việc này đơn giản bằng opencv thuần thôi. Ví dụ đây là 2 filter cần dùng trong bài: # Hàm tạo filter sapie
def apply_sepia(frame, intensity=0.5):
frame = verify_alpha_channel(frame)
frame_h, frame_w, frame_c = frame.shape
sepia_bgra = (20, 66, 112, 1)
overlay = np.full((frame_h, frame_w, 4), sepia_bgra, dtype='uint8')
cv2.addWeighted(overlay, intensity, frame, 1.0, 0, frame)
return frame
# Hàm tạo filter màu
def apply_color_overlay(frame, intensity=0.5, blue=0, green=0, red=0):
frame = verify_alpha_channel(frame)
frame_h, frame_w, frame_c = frame.shape
sepia_bgra = (blue, green, red, 1)
overlay = np.full((frame_h, frame_w, 4), sepia_bgra, dtype='uint8')
cv2.addWeighted(overlay, intensity, frame, 1.0, 0, frame)
return frame Code language: PHP (php)
Trong project của mình, các filter này được lưu vào file filters.py để các file khác có thể import và sử dụng cho nhanh gọn. Phần 3 Train model nhận diện cảnh trong/ngoàiChuẩn bị dữ liệuBài này mình sẽ sử dụng kỹ thuật image generator. Đây là một kỹ thuật thông dụng mà mình đã sử dụng trong rất nhiều bài. Thư mục data của chúng ta sẽ chứa 2 thư mục con:
Dữ liệu này các bạn tải tại Thư viện Mì AI: https://miai.vn/thu-vien-mi-ai . Bạn xem video clip để biết cách tải về. Bây giờ ta sẽ sử dụng Image Generator để load dữ liệu từ 2 folder trên, sử dụng tên folder làm nhãn: # Đường dẫn đến folder ảnh
data_dir = pathlib.Path('./data')
# Tên class lấy bằng đúng tên thư mục (indoor, outdoor)
class_names = np.array([folder.name for folder in data_dir.glob('*') if folder.name != ".DS_Store"])
# Tạo ra một image_gen, có thực hiện rescale
image_generator = ImageDataGenerator(rescale=1/255, validation_split=0.2)
train_data_gen = image_generator.flow_from_directory(directory=str(data_dir), batch_size=batch_size,
classes=list(class_names), target_size=(input_size[0], input_size[1]),
shuffle=True, subset="training")
test_data_gen = image_generator.flow_from_directory(directory=str(data_dir), batch_size=batch_size,
classes=list(class_names), target_size=(input_size[0], input_size[1]),
shuffle=True, subset="validation") Code language: PHP (php)
Các bạn để ý mình có dùng rescale để đưa các giá trị trong ảnh về khoảng [0,1] nhé. Rồi, data vậy là okie, nếu các bạn chạy thành công sẽ thấy hiện ra màn hình: Found 640 images belonging to 2 classes.
Found 160 images belonging to 2 classes.
Dòng trên là số ảnh cho train và dòng thứ 2 là số ảnh cho validation nhé. Tạo cấu trúc modelTrong bài này mình sẽ sử dụng mạng MobilenetV2 và có xào nấu tý để ghép nối thành mạng của mình: # Load MobileNetV2
mobilenet_model = MobileNetV2(input_shape=input_shape)
# Bỏ đi layer cuối cùng (FC)
mobilenet_model.layers.pop()
# Đóng băng các layer (trừ 4 layer cuối)
for layer in mobilenet_model.layers[:-4]:
layer.trainable = False
mobilenet_output = mobilenet_model.layers[-1].output
# Tạo các layer mới
output = Dense(num_classes, activation="softmax")
# Lấy input từ output của MobileNet
output = output(mobilenet_output)
# Tạo model với input của MobileNet và output là lớp Dense vừa thêm
model = Model(inputs=mobilenet_model.inputs, outputs=output)
# In cấu trúc mạng
model.summary()
# Compile model
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]) Code language: PHP (php)
Và cấu trúc mạng của chúng ta sẽ được in ra màn hình: __________________________________________________________________________________________________
dense_1 (Dense) (None, 2) 2562 global_average_pooling2d_1[0][0]
==================================================================================================
Total params: 2,260,546
Trainable params: 414,722
Non-trainable params: 1,845,824
__________________________________________________________________________________________________
Các bạn để ý model Dense mà ta đã thêm vào đã đúng. Tổng số tham số của mạng là 2,260,546 trong đó có 414,722 đã fix cứng, còn lại 1,845,824 tham số sẽ được train. Train modelOkie rồi, train thôi các bạn: # Định nghĩa batch_size và epochs
batch_size = 32
epochs = 1
# Load các data gen
train_generator, validation_generator, class_names = get_generator()
# Tạo model
model = get_model(num_classes = len(class_names))
# Tạo callback để lấy weight mới nhất
checkpoint = ModelCheckpoint("models/my_model-" + "-loss-{val_loss:.2f}-acc-{val_accuracy:.2f}.h5", save_best_only=True, verbose=1)
# Train model
training_steps_per_epoch = np.ceil(train_generator.samples / batch_size)
validation_steps_per_epoch = np.ceil(validation_generator.samples / batch_size)
model.fit_generator(train_generator, steps_per_epoch=training_steps_per_epoch,
validation_data=validation_generator, validation_steps=validation_steps_per_epoch,
epochs=epochs, verbose=1, callbacks=[ checkpoint])
# Lưu model sau khi train xong all epochs
model.save("models/my_model.h5") Code language: PHP (php)
Ở đây các bạn cần chú ý 2 vấn đề:
Phần 4 Kiểm thử model tự động chỉnh màu ảnhBây giờ chúng ta thử nghiệm xem model chạy thế nào với dữ liệu thực tế nào. Chúng ta sẽ đưa vào một bức ảnh, đọc ảnh và chuyển thành tensor (nhớ rescale / 255) và predict: # Thay đổi ảnh ở đây
image_path = "test_data/b.jpg"
# Đọc ảnh
image = cv2.imread(image_path)
image_org = image.copy()
# Chuyển đổi thành tensor
image = cv2.resize(image, dsize=input_size[:2])
image = image/255
image = np.expand_dims(image, axis=0)
# Tạo model
model = get_model()
# load the optimal weights
model.load_weights("models/my_model--loss-0.41-acc-0.94.h5")
# Tiến hành predict
class_names = ['indoor','landscape']
output = model.predict(image) Code language: PHP (php)
Output của chúng ta sẽ là một vector $p$ như sau: $$p = \begin{bmatrix} p1 & p2 \end{bmatrix}$$ Trong đó :
Do đó chúng ta chỉ cần dùng $np.argmax$ là có thể lấy được giá trị 0,1 để biết ảnh hiện tại đang là class nào: class_name = class_names[np.argmax(output)]
Và bước cuối cùng, khi đã biết class ảnh ta sẽ áp dụng các filter tương ứng:
# Nếu là landscape thì apply overlay
if class_name == "landscape":
filter_image = filters.apply_color_overlay(image_org, intensity=.2, red=250, green=100, blue=0)
cv2.putText(filter_image,class_name,(50,50),cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)
else:
# Nếu là indoor thì apply sepia
filter_image = filters.apply_sepia(image_org, intensity=.8)
cv2.putText(filter_image, class_name, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
cv2.imshow('orginal_color', image_org)
cv2.imshow('color_overlay', filter_image)
cv2.waitKey()
cv2.destroyAllWindows() Code language: PHP (php)
Ở đây mình sẽ thử với 2 ảnh KHÔNG CÓ TRONG TẬP TRAIN để xem model như nào. Và kết quả khá ổn! Với ảnh ngoài trời thì sẽ được thêm chút nắng ấm xa dần: Còn nếu là ảnh giường chiếu/trong phòng thì thêm tý hiệu ứng Sapie nào cho nó ấm cúng (anh em nào nhận ra ảnh này quen ko nhở ) Mình có code viết sẵn kèm file pretrain h5 tại github này nhé. Các bạn có thể tải về để thử luôn. Chúc các bạn thành công! #MìAI Fanpage:http://facebook.com/miaiblog Cảm ơn bài tham khảo tuyệt vời tại đây. |