SRCNN을 구현해보고 소스코드에 대해 이해한 내용과 아직 이해하지 못한 내용을 정리해본다.

 

Training

우선 Training을 위한 코드에 대해 이해하도록 해보자.

전체 소스코드는 아래 py파일을 확인해보자.

train.py
0.00MB

 

Argparse 설정
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, default='./train_file')
    parser.add_argument('--eval-file', type=str, default='./eval_file')
    parser.add_argument('--outputs-dir', type=str, default='./outputs')
    parser.add_argument('--num-channels', type=int, default=3)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--psnr-lr', type=float, default=1e-3)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--patch-size', type=int, default=160)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--checkpoint-file', type=str, default='checkpoint-file.pth')
    args = parser.parse_args()
위의 소스코드에서 맨 마지막 args를 print해보면 default에 입력된 인자들이 나열된다.

필요할 때마다 "args.train_file", "args.patch_size" 와 같은 형태로 인스턴스 부르듯이 인자들을 호출하면 된다.

 

weight를 저장 할 경로 설정
args.outputs_dir = os.path.join(args.outputs_dir,  f"SRCNNx{args.scale}")
if not os.path.exists(args.outputs_dir):       
    os.makedirs(args.outputs_dir)

위의 소스코드는 아주 유용한 소스코드인데, args.outputs_dir의 인자에 f"SRCNNx{args.scale}" 경로를 추가해준다.

 

그리고 경로가 존재하지 않는다면 os.makedirs() 함수를 통해 경로를 생성해준다. (이를 통해 경로로 인한 에러 문제를 자연스럽게 해결해준다.)

 

GPU 디바이스 설정
cudnn.benchmark = True # 잘 모르겠다.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

device에 cuda를 사용한다는 정보를 입력한다.

 

Torch Seed 설정
torch.manual_seed(args.seed)

args.seed의 값 자체는 의미는 없고, 랜덤 변수를 생성하기 위한 구문이라고 이해하면 될 듯 하다.

 

model 설정
model = SRCNN(num_channels=args.num_channels, scale=args.scale).to(device)

SRCNN을 사용하고, 입력 채널과 크기를 파라미터로 입력한다.

--> SRCNN(입력채널, x2 업스케일링)

Loss 및 optimizer 설정
pixel_criterion = nn.MSELoss().to(device)
psnr_optimizer = torch.optim.Adam(model.parameters(), args.psnr_lr, (0.9, 0.999))

total_epoch = args.num_epochs
start_epoch = 0
best_psnr = 0

 

체크포인트 weight 불러오기
if os.path.exists(args.checkpoint_file):
    checkpoint = torch.load(args.checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    psnr_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']
    best_psnr = checkpoint['best_psnr']

위의 소스코드에서는 원하는 epoch에서 중단하고 다시 실행하기 위한 구문이다.

실제로 torch.savetorch.load가 어떻게 저장되고 읽는지 파악하기 좋은 구문이라고 생각된다.

checkpoint는 dict으로 저장되는데, keys()를 통해 어떻게 생겼는지 확인해보면 단숨에 이해가 된다.

 

Example) torch.load의 형태 파악
checkpoint.keys() #dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss', 'best_psnr'])
checkpoint['epoch'] #몇 번째 epoch에서 시작할지 확인할 수 있다.
checkpoint['model_state_dict'] #weight와 bias가 어떻게 저장되었는지 확인할 수 있다.
checkpoint['optimizer_state_dict'] #optimizer의 값들을 확인할 수 있다.
checkpoint['loss'] #현재 학습된 loss의 값을 확인할 수 있다.
checkpoint['best_psnr'] #현재 학습된 최적의 psnr을 확인할 수 있다.

추가적으로 모델을 저장하는 2가지 방법에 대해서 이해하고 넘어가자

 

  • torch.save(model명, PATH + 'model.pth') / torch.load(PATH + 'model.pth')
  • torch.save([model명].state_dict(), PATH + 'model.pth') / torch.load_state_dict(torch.load(PATH + 'model.pth'))

그냥 torch.save는 모델의 모든 정보를 저장하게 된다. (저장되는 정보는 위의 Example을 확인해보면 된다.)

 

state_dict()은 weigth와 bias정도만 저장된다고 한다.

 

그래서 state_dict()이 더 작은 용량을 갖게 되고, 보통 test를 위해서는 state_dict()의 정보만 load한다고 한다.

 

참고: https://tutorials.pytorch.kr/recipes/recipes/what_is_state_dict.html

 

Log 정보 프린트
logger.info(
                f"SRCNN MODEL INFO:\n"
                f"\tNumber of channels:               {args.num_channels}\n"
                f"SRCNN TRAINING INFO:\n"
                f"\tTotal Epoch:                          {args.num_epochs}\n"
                f"\tStart Epoch:                          {start_epoch}\n"
                f"\tTrain directory path:                {args.train_file}\n"
                f"\tTest directory path:                 {args.eval_file}\n"
                f"\tOutput weights directory path:  {args.outputs_dir}\n"
                f"\tPSNR learning rate:                 {args.psnr_lr}\n"
                f"\tPatch size:                             {args.patch_size}\n"
                f"\tBatch size:                             {args.batch_size}\n"
                )

 

스케줄러 설정 (https://wikidocs.net/157282)
psnr_scheduler = torch.optim.lr_scheduler.StepLR(psnr_optimizer, step_size=30, gamma=0.1)    scaler = amp.GradScaler()
  • optimizer: 이전에 정의한 optimizer 변수명을 넣어준다.
  • step_size: 몇 epoch마다 lr을 감소시킬지가 step_size를 의미한다.
  • gamma: gamma 비율로 lr을 감소시킨다.

 

데이터셋 & 데이터셋 설정
train_dataset = Dataset(args.train_file, args.patch_size, scale=args.scale)
train_dataloader = DataLoader(
                            dataset=train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True
                            )
eval_dataset = Dataset(args.eval_file, args.patch_size, scale=args.scale)
eval_dataloader = DataLoader(
                                dataset=eval_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                pin_memory=True
                                )

내가 지금 소스코드에서 가장 많은 시간을 들여 이해한 부분이다.

 

디테일한 부분은 다음에 따로 설명하도록 하고, 어떤 구조를 갖는지에 대해서만 설명하겠다.

 

Example) dataset의 형태 파악
train_dataset.__len__() # 700 --> 읽어들인 이미지의 개수
np.array(train_dataset).shape # (700, 2) --> 700이미지가 2종류(HR이미지와 LR이미지)로 존재
np.array(train_dataset[0]).shape # (, 2) --> 700이미지중 1번째 이미지가 2종류(HR/LR)로 존재
np.array(train_dataset[0][0]).shape # (3, 80, 80) --> 1번째 이미지는 (C,H,W) 순으로 존재 (채널, 세로, 가로)

eval_dataset.__len__() # 100 --> 읽어들인 이미지의 개수
Example) dataloader의 형태 파악
batch_iterator = iter(train_dataloader)
lr_images, hr_images = next(batch_iterator) # 

 

우선 내가 이해한 부분은 여기까지이고, 이후 다른 코드는 크게 어려운 부분은 없었다.

 

혹시라도 이글을 읽는 분들중에 틀린부분이나 모르는부분은 댓글로 남겨주시면 감사하겠습니다.

'연구내용 > DeepLearning' 카테고리의 다른 글

DCN MNIST 구현 패키지  (0) 2022.03.24
VSRnet_torch 패키지  (0) 2022.02.09
VSRnet (keras) 코드 및 패키지  (0) 2022.02.07
model.train(), model.eval() 의미  (0) 2022.02.03

+ Recent posts