1. Array shape이 다른 데이터를 로드하기 위해 사용!
일반적으로 Pytorch의 Dataloader는 Dataset의 리턴을 batch 단위로 concatenate해서 리턴하도록 설계되어있어요. 그러나 각 sample의 array shape이 일치하지 않으면 오류가 발생한답니다.
array shape이 다른 대표적인 상황
- 시퀀스 길이가 제각각인 시계열 데이터
- 샘플마다 Node 수가 다른 그래프 데이터
이러한 상황에 사용할 수 있는 것이 바로 collate_fn 입니다!
2. collate_fn 이란?
collate_fn은 Dataset에서 getitem 함수로 리턴하는 샘플을 하나의 batch로 모아주는 함수입니다. Dataloader에 입력으로 collate_fn을 넣게되면 작동한답니다.
3. 예시
1) 시퀀스 길이가 다른 시계열 데이터에 padding을 추가하는 collate_fn
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, src_seq, trg_seq):
self.src_seq = src_seq
self.trg_seq = trg_seq
def __len__(self):
return len(self.src_seq)
def __getitem__(self, index):
source = torch.randn((self.src_seq[index],))
target = torch.randn((self.trg_seq[index],))
return source, target
from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
# batch는 (sequence, label)의 리스트입니다.
src_seq = [item[0] for item in batch]
trg_seq = [item[1] for item in batch]
# 시퀀스를 패딩하여 동일한 길이로 맞춰줍니다.
padded_sequences = pad_sequence(src_seq, batch_first=False, padding_value=0)
padded_targets = pad_sequence(trg_seq, batch_first=False, padding_value=0)
return padded_sequences, padded_targets
if __name__ == '__main__':
import numpy as np
src_seq_len = np.arange(10,100)
trg_seq_len = np.arange(100,10, -1)
dataset = CustomDataset(src_seq_len, trg_seq_len)
dataloader = DataLoader(dataset=dataset,
batch_size=16,
collate_fn=collate_fn)
src, trg = next(iter(dataloader))
print(src.shape)
print(trg.shape)
2) 노드 수가 다른 그래프 데이터를 list로 묶는 collate_fn
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, num_graphs, max_nodes, num_features):
self.num_graphs = num_graphs
self.max_nodes = max_nodes
self.num_features = num_features
def __len__(self):
return self.num_graphs
def __getitem__(self, index):
num_nodes = torch.randint(1, self.max_nodes + 1, (1,)).item()
graph = torch.randn((num_nodes, self.num_features))
# Create a random adjacency matrix for the graph
adj_matrix = torch.randint(0, 2, (num_nodes, num_nodes)).float()
adj_matrix = (adj_matrix + adj_matrix.t()) / 2 # Symmetrize the matrix
adj_matrix.fill_diagonal_(1) # Ensure self-loops
label = torch.randint(0, 2, (1,)) # Binary label (0 or 1)
return graph, adj_matrix, label
def collate_fn(batch):
graphs = [item[0] for item in batch]
adj_matrices = [item[1] for item in batch]
labels = [item[2] for item in batch]
return graphs, adj_matrices, labels
if __name__ == '__main__':
dataset = CustomDataset(num_graphs=100, max_nodes=30, num_features=5)
dataloader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
graphs, adj_matrices, labels = next(iter(dataloader))
print(graphs) # node 정보
print(adj_matrices) # node간의 연결 정보
print(labels) # labels
'DeepLearning' 카테고리의 다른 글
[On-device AI] Pytorch 모델 TFLite로 변환하기(torch->onnx->tf->tflite) (4) | 2024.12.26 |
---|---|
[CV] ViT 모델 구조 정리 (2) | 2024.10.06 |
[MLLM] 구글의 오픈소스 VLM 'PaliGemma' 로컬 튜토리얼 (1) | 2024.10.02 |
[HuggingFace] OSError: 허깅페이스 403 에러를 아십니까? (2) | 2024.10.01 |