본문 바로가기

인공지능(AI) 이론과 코드/5. 컴퓨터 비전(CV)

torch.cat()과 torch.stack()의 차이점

파이토치에서 텐서들을 서로 병합하는(붙이는) 2가지 함수

torch.cat()과 torch.stack()의 차이점에 대해 알아봅니다.

 

  • torch.cat()은 주어진 차원을 기준으로 주어진 텐서들을 붙입니다(concatenate).

  • torch.stack()은 새로운 차원으로 주어진 텐서들을 붙입니다.

  • 따라서, (3, 4)의 크기(shape)를 갖는 2개의 텐서 A와 B를 붙이는 경우,
    torch.cat([A, B], dim=0)의 결과는 (6, 4)의 크기(shape)를 갖고,
    torch.stack([A, B], dim=0)의 결과는 (2, 3, 4)의 크기를 갖습니다.

 

  • 예를 들어 설명하기 위해, 아래 두 개의 텐서 t1, t2를 예시로 선언해보겠습니다.
    t1 = torch.tensor([[1, 2],
                       [3, 4]])
    t2 = torch.tensor([[5, 6],
                       [7, 8]])
    

 

  • 이때,torch.cat()의 동작은 다음과 같습니다.
    >>> torch.cat((t1, t2), dim=0) # dim=0인 경우
    tensor([[1, 2],
            [3, 4],
            [5, 6],
            [7, 8]])
    
    >>> torch.cat((t1, t2), dim=1) # dim=1인 경우
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])
    

 

  • torch.stack()은 다음과 같습니다.
    >>> torch.stack((t1, t2))
    tensor([[[1, 2],
             [3, 4]],
     
            [[5, 6],
             [7, 8]]])

 

https://discuss.pytorch.kr/t/torchcat%EA%B3%BC-torchstack%EC%9D%80-%EC%96%B4%EB%96%BB%EA%B2%8C-%EB%8B%A4%EB%A5%B8%EA%B0%80%EC%9A%94/26

 

torch.cat()과 torch.stack()은 어떻게 다른가요?

공식 홈페이지와 StackOverflow 등에서 자주 보이는 질문과 답변을 번역하고 있습니다. 다음 링크에서 원문을 함께 찾아보실 수 있습니다. 원문: python - What's the difference between torch.stack() and torch.cat()

discuss.pytorch.kr

 

반응형
LIST