앞으로 딥러닝 코드를 이해하기 위해서는 텐서를 자유롭게 다룰 수 있어야 한다.

 

텐서를 다루기 위해 Pytorch를 사용하게 될 텐데, 텐서를 다루는기 위한 Pytorch의 기초적인 문법에 대해서 작성해 보겠다.

 

1. dim()

2. shape()

3. size()

4. view()

5. squeeze()

6. unsqueeze()

7. cat()

8. ones_like()

9. zeros_like()

 

import torch
import numpy as np
numpy 선언 이후, Tensor로 변환한다. (Tensor로 변환함으로써 다양한 매소드를 지원할 수 있게 된다.)
t = torch.FloatTensor([0., 1., 2., 3., 4., 5., 6.])

print(t.dim())    # 1
print(t.shape()) # torch.Size([7])
print(t.size())    # torch.Size([7])
view()를 통해 원하는 텐서 크기로 변환할 수 있다. (-'1 은 알아서 맞춰라' 라는 의미를 갖고 있다.)
t = np.array([[[0, 1, 2],
               [3, 4, 5]],
              [[6, 7, 8],
               [9, 10, 11]]])

ft = torch.FloatTensor(t)

print(ft.shape)                  # torch.Size([2, 2, 3])
print(ft.view([-1,3]).shape)   # torch.Size([4, 3])
print(ft.view([-1,1,3]).shape) # torch.Size([4, 1, 3])
squeeze()는 1인 차원을 제거한다.
ft = torch.FloatTensor([[0],[1],[2]])

print(ft.shape)              # torch.Size([3, 1])
print(ft.squeeze().shape) # torch.Size([3])
 
unsqueeze()는 원하는 위치에 차원을 추가한다. (view()를 통해 동일한 차원을 만들 수도 있다.)
ft = torch.Tensor([0,1,2])

print(ft.shape)                    # torch.Size([3])

print(ft.unsqueeze(0).shape)  # torch.Size([1, 3])
print(ft.view(1,-1).shape)       # torch.Size([1, 3])

print(ft.unsqueeze(1).shape)  # torch.Size([3, 1])

print(ft.unsqueeze(-1).shape) # torch.Size([3, 1])
cat()을 통해 두 개의 텐서를 연결할 수 있다. (dim=0은 행으로 추가하고, dim=1은 열로 추가하라는 의미이다.)
x = torch.FloatTensor([[1,2],[3,4]])
y = torch.FloatTensor([[5,6],[7,8]])

print(torch.cat([x,y], dim=0).shape) # torch.Size([2, 4])
print(torch.cat([x,y], dim=1).shape) # torch.Size([4, 2])
ones_like()와 zeros_like()를 통해 매개변수와 동일한 shape을 갖는 텐서를 생성한다. (ones -> 1, zeros -> 0)
x = torch.FloatTensor([[0, 1, 2], [2, 1, 0]])

print(torch.ones_like(x)) 
print(torch.zeros_like(x))

'개인공부 > Python' 카테고리의 다른 글

[vscode] Github 연동 (git clone)  (0) 2022.01.20
im2col 함수  (0) 2022.01.18
RuntimeError (CUDA)  (0) 2022.01.13
Python, Pytorch, Torchvision 다운그레이드  (0) 2022.01.11
파라미터 저장을 위한 예제 (hdf5)  (0) 2021.12.20

+ Recent posts