DeepLearning

[Pytorch] collate_fn: 데이터 샘플을 배치로 합치는 함수

gyeongtiger 2024. 8. 8. 22:00

1. Array shape이 다른 데이터를 로드하기 위해 사용!

일반적으로 Pytorch의 Dataloader는 Dataset의 리턴을 batch 단위로 concatenate해서 리턴하도록 설계되어있어요. 그러나 각 sample의 array shape이 일치하지 않으면 오류가 발생한답니다.

array shape이 다른 대표적인 상황
- 시퀀스 길이가 제각각인 시계열 데이터
- 샘플마다 Node 수가 다른 그래프 데이터

이러한 상황에 사용할 수 있는 것이 바로 collate_fn 입니다!

 

2. collate_fn 이란?

collate_fn은 Dataset에서 getitem 함수로 리턴하는 샘플을 하나의 batch로 모아주는 함수입니다. Dataloader에 입력으로 collate_fn을 넣게되면 작동한답니다.

 

3. 예시

1) 시퀀스 길이가 다른 시계열 데이터에 padding을 추가하는 collate_fn

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, src_seq, trg_seq):
        self.src_seq = src_seq
        self.trg_seq = trg_seq
        
    def __len__(self):
        return len(self.src_seq)
    
    def __getitem__(self, index):
        source = torch.randn((self.src_seq[index],))
        target = torch.randn((self.trg_seq[index],))
        return source, target
    

from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    # batch는 (sequence, label)의 리스트입니다.
    src_seq = [item[0] for item in batch]
    trg_seq = [item[1] for item in batch]
    
    # 시퀀스를 패딩하여 동일한 길이로 맞춰줍니다.
    padded_sequences = pad_sequence(src_seq, batch_first=False, padding_value=0)
    padded_targets = pad_sequence(trg_seq, batch_first=False, padding_value=0)
    
    return padded_sequences, padded_targets

if __name__ == '__main__':
    import numpy as np
    src_seq_len = np.arange(10,100)
    trg_seq_len = np.arange(100,10, -1)
    dataset = CustomDataset(src_seq_len, trg_seq_len)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=16,
                            collate_fn=collate_fn)
    
    src, trg = next(iter(dataloader))

    print(src.shape)
    print(trg.shape)


2) 노드 수가 다른 그래프 데이터를 list로 묶는 collate_fn

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, num_graphs, max_nodes, num_features):
        self.num_graphs = num_graphs
        self.max_nodes = max_nodes
        self.num_features = num_features
        
    def __len__(self):
        return self.num_graphs
    
    def __getitem__(self, index):
        num_nodes = torch.randint(1, self.max_nodes + 1, (1,)).item()
        graph = torch.randn((num_nodes, self.num_features))
        
        # Create a random adjacency matrix for the graph
        adj_matrix = torch.randint(0, 2, (num_nodes, num_nodes)).float()
        adj_matrix = (adj_matrix + adj_matrix.t()) / 2  # Symmetrize the matrix
        adj_matrix.fill_diagonal_(1)  # Ensure self-loops
        
        label = torch.randint(0, 2, (1,))  # Binary label (0 or 1)
        
        return graph, adj_matrix, label

def collate_fn(batch):
    graphs = [item[0] for item in batch]
    adj_matrices = [item[1] for item in batch]
    labels = [item[2] for item in batch]
    
    return graphs, adj_matrices, labels

if __name__ == '__main__':
    dataset = CustomDataset(num_graphs=100, max_nodes=30, num_features=5)
    dataloader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn)

    graphs, adj_matrices, labels = next(iter(dataloader))

    print(graphs) # node 정보
    print(adj_matrices) # node간의 연결 정보
    print(labels) # labels