실제로 CNN 네트워크를 구현하다보면 3차원 데이터를 다루는 것이 아닌 4차원 데이터를 다루게 된다.
이론만 접한 사람들에게는 '응? 이게 무슨 의미야' 라고 생각될 수 있다.
하지만 실제로 네트워크는 효율적으로 계산하기 위해서 4차원 데이터를 기반으로 CNN 네트워크를 동작시킨다.
torch.size([batch size, channel, height, width])
배치 사이즈가 맨 앞단에 존재하고, 그 뒤에 채널의 깊이를 의미하는 텐서를 다루게 된다.
이렇게 복잡해 보이는 4차원 데이터를 im2col이라는 함수를 적용하여 단순하게 계산할 수 있도록 한다.
우선 아래와 같은 간단한 예제를 들어 설명하겠다.
3차원 입력데이터를 필터링하기 좋게 전개하게 되는데, 구체적으로는 입력 데이터에서 필터를 적용하는 영역을 한 줄로 늘어놓게 된다.
이 전개를 필터에 적용하는 모든 영역에서 수행하는게 im2col이다.
실제로는 필터의 영역이 겹치는 경우가 대부분이라 im2col로 전개하게 되면 원래의 원소 수보다 더 많아지게 되는 단점이 있지만, 컴퓨터는 큰 행렬을 묶어서 계산하는데 최적화 되어 있으므로 빠르게 계산할 수 있다는 장점을 갖게 된다.
추가적으로 im2col에 대해서 그래픽하게 잘 표현해놓은 사이트를 공유한다.
'개인공부 > Python' 카테고리의 다른 글
[ImportError] DLL load failed: 지정된 모듈을 찾을 수 없습니다. (0) | 2022.01.20 |
---|---|
[vscode] Github 연동 (git clone) (0) | 2022.01.20 |
Pytorch - 텐서 조작하기 (0) | 2022.01.18 |
RuntimeError (CUDA) (0) | 2022.01.13 |
Python, Pytorch, Torchvision 다운그레이드 (0) | 2022.01.11 |