Image data augmentation I studied this time.
The model was not detecting small objects well, so I doubled the 40,000-image dataset through zoom-out and rotation augmentation.
import os
import glob
import random
import cv2
import numpy as np
from tqdm import tqdm
import albumentations as A
import logging
import time
from datetime import datetime
# Set up the log file name and a separate record of processed files (fixed path so runs can resume)
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
log_filename = f'logs/image_processing_{current_time}.log'
processed_record = 'logs/processed_files.txt'
# Create the logs directory
if not os.path.exists('logs'):
    os.makedirs('logs')
# Configure the logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# Configure the file handler
file_handler = logging.FileHandler(log_filename)
file_handler.setLevel(logging.INFO)
# Configure the log format
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# Add the handler to the logger
logger.addHandler(file_handler)
# Function to save images
def save_image(image, path):
    image = image.astype(np.uint8)  # Ensure the image is in the correct format
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV
    cv2.imwrite(path, image)
    logger.info(f"Saved image: {path}")
# Load processed files to resume from where it left off
def load_processed_files(record_path):
    if os.path.exists(record_path):
        with open(record_path, 'r') as f:
            processed_files = f.read().splitlines()
    else:
        processed_files = []
    return processed_files
# Function to perform image augmentation
def zoom_out_pad_rotate(image, bboxes, class_labels, scale_range=(0.4, 0.8)):
    augmented_images = []
    transform = A.Compose([
        A.Rotate(limit=(-90, 90), p=1),  # random rotation between -90 and 90 degrees
        A.HorizontalFlip(p=0.5),
        A.Affine(
            scale=scale_range,  # scaling below 1.0 acts as a zoom out with border padding
            p=1, keep_ratio=True,
            mode=cv2.BORDER_REPLICATE),
        A.GaussianBlur(p=0.5),
        A.GaussNoise(p=0.6),
        A.RGBShift(p=0.5),
        A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3, p=0.5),
        A.CoarseDropout(max_holes=50, max_height=30, max_width=30, p=1)
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
    augmented = transform(image=image, bboxes=bboxes, class_labels=class_labels)
    augmented_images.append((augmented['image'], augmented['bboxes'], augmented['class_labels']))
    return augmented_images
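# A quick standalone check of the transform above (a sketch, not part of the pipeline;
# 'sample.jpg' and the single dummy YOLO box are hypothetical placeholders):
#   img = cv2.cvtColor(cv2.imread('sample.jpg'), cv2.COLOR_BGR2RGB)
#   aug_img, aug_boxes, aug_labels = zoom_out_pad_rotate(img, [[0.5, 0.5, 0.2, 0.2]], [0])[0]
#   print(aug_img.shape, aug_boxes, aug_labels)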
# Main processing function
def process_images(folder, label_folder, scale_range=(0.4, 0.8), max_images=10):
    start_time = time.time()
    image_extensions = ['jpg', 'jpeg', 'png']
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder, f'*.{ext}')))
    total_images = len(image_files)
    logger.info(f"Total images in folder: {total_images}")
    output_img_dir = 'zoomed_out_images_test'
    output_label_dir = 'zoomed_out_labels_test'
    if not os.path.exists(output_img_dir):
        os.makedirs(output_img_dir)
    if not os.path.exists(output_label_dir):
        os.makedirs(output_label_dir)
    # Load already-processed files to resume
    logger.info("Starting augmentation!")
    processed_files = load_processed_files(processed_record)
    for idx, file in enumerate(tqdm(image_files[:max_images], desc="Processing images", leave=True)):
        if file in processed_files:
            continue  # Skip already processed files
        # Load image and convert BGR (OpenCV) to RGB for the color-based transforms
        img = cv2.imread(file)
        if img is None:
            logger.warning(f"Failed to load image: {file}. Skipping.")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Load corresponding label file
        file_name = os.path.splitext(os.path.basename(file))[0]
        label_file = os.path.join(label_folder, f"{file_name}.txt")
        if not os.path.exists(label_file):
            logger.warning(f"Label file {label_file} does not exist. Skipping image {file}.")
            continue
        with open(label_file, 'r') as f:
            lines = f.readlines()
        bboxes = []
        class_labels = []
        for line in lines:
            temp = line.strip().split(" ")
            class_labels.append(int(temp[0]))
            bboxes.append([float(temp[1]), float(temp[2]), float(temp[3]), float(temp[4])])
        augment_start_time = time.time()
        augmented_images = zoom_out_pad_rotate(img, bboxes, class_labels, scale_range)
        augment_end_time = time.time()
        augment_time = augment_end_time - augment_start_time
        logger.info(f"Time taken for augmentation: {augment_time:.2f} seconds")
        for aug_idx, (aug_img, aug_bboxes, aug_class_labels) in enumerate(augmented_images):
            save_img_path = os.path.join(output_img_dir, f"{file_name}_aug_{aug_idx}.jpg")
            save_label_path = os.path.join(output_label_dir, f"{file_name}_aug_{aug_idx}.txt")
            save_image(aug_img, save_img_path)
            with open(save_label_path, 'w') as f:
                for bbox, cls in zip(aug_bboxes, aug_class_labels):
                    f.write(f"{cls} {bbox[0]} {bbox[1]} {bbox[2]} {bbox[3]}\n")
            logger.info(f"Saved label: {save_label_path}")
        # Mark file as processed so the run can resume later
        with open(processed_record, 'a') as f:
            f.write(file + '\n')
    # Log the total image count and the overall processing time
    logger.info(f"Total number of images processed: {total_images}")
    end_time = time.time()
    total_process_time = end_time - start_time
    logger.info(f"Total time taken for processing: {total_process_time:.2f} seconds")
# Example usage
folder_path = ""
label_folder_path = ""
process_images(folder_path, label_folder_path, scale_range=(0.4, 0.8), max_images=40000)
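To double-check that the YOLO labels still line up after rotation and scaling, a small visualization script like the sketch below can help. It is not part of the pipeline above; the file name 'example_aug_0' and the output 'bbox_check.jpg' are placeholders, and the folders are the ones created by the script.

import os
import cv2

# Draw the augmented YOLO boxes back onto one augmented image to verify alignment.
img_path = os.path.join('zoomed_out_images_test', 'example_aug_0.jpg')  # placeholder name
label_path = os.path.join('zoomed_out_labels_test', 'example_aug_0.txt')

img = cv2.imread(img_path)
h, w = img.shape[:2]
with open(label_path, 'r') as f:
    for line in f:
        cls, cx, cy, bw, bh = line.split()
        # Convert normalized YOLO (cx, cy, w, h) back to pixel corner coordinates
        cx, cy, bw, bh = float(cx) * w, float(cy) * h, float(bw) * w, float(bh) * h
        x1, y1 = int(cx - bw / 2), int(cy - bh / 2)
        x2, y2 = int(cx + bw / 2), int(cy + bh / 2)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(img, cls, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
cv2.imwrite('bbox_check.jpg', img)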
I confirmed that some events that were previously missed are now detected!
However, detection performance still seems to improve only when the view is zoomed in.
Approaches I'm considering next: obtaining additional images with generative AI, or switching to a different model.