실제로 CNN 네트워크를 구현하다보면 3차원 데이터를 다루는 것이 아닌 4차원 데이터를 다루게 된다.

 

이론만 접한 사람들에게는 '응? 이게 무슨 의미야' 라고 생각될 수 있다.

 

하지만 실제로 네트워크는 효율적으로 계산하기 위해서 4차원 데이터를 기반으로 CNN 네트워크를 동작시킨다.

torch.size([batch size, channel, height, width])

 

배치 사이즈가 맨 앞단에 존재하고, 그 뒤에 채널의 깊이를 의미하는 텐서를 다루게 된다.

 

이렇게 복잡해 보이는 4차원 데이터를 im2col이라는 함수를 적용하여 단순하게 계산할 수 있도록 한다.

 

우선 아래와 같은 간단한 예제를 들어 설명하겠다.

 

3차원 입력데이터를 필터링하기 좋게 전개하게 되는데, 구체적으로는 입력 데이터에서 필터를 적용하는 영역을 한 줄로 늘어놓게 된다.

 

이 전개를 필터에 적용하는 모든 영역에서 수행하는게 im2col이다.

 

실제로는 필터의 영역이 겹치는 경우가 대부분이라 im2col로 전개하게 되면 원래의 원소 수보다 더 많아지게 되는 단점이 있지만, 컴퓨터는 큰 행렬을 묶어서 계산하는데 최적화 되어 있으므로 빠르게 계산할 수 있다는 장점을 갖게 된다.

 

추가적으로 im2col에 대해서 그래픽하게 잘 표현해놓은 사이트를 공유한다.

https://hackmd.io/@bouteille/blog-post-cnnumpy-fast#

+ Recent posts