Language&Framework&Etc/Pytorch

Pytorch dataset(sampler)

머리올리자 2022. 6. 21. 22:05

코드를 보다가 PyTorch에서 sampler를 사용하는 것을 발견하여, 아래와 같이 간단한 코드를 작성해 output을 확인하며 어떤 식으로 동작하는지 살펴보았다.

import random
import numpy as np
import torch
from torch.utils.data import Dataset, RandomSampler, BatchSampler


# Reproducibility: pin every RNG this script can touch to one seed.
random_seed = 8138

random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)  # seeds every device when using multi-GPU
# Force cuDNN to pick deterministic kernels (and skip autotuning benchmarks).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class default_dataset(Dataset):
    """Minimal map-style dataset that pairs each sample with its label.

    ``__getitem__`` returns a dict with keys ``"input"`` and ``"label"``.
    """

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        """Number of samples (length of the data container)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the idx-th sample/label pair as a dict."""
        return dict(input=self.data[idx], label=self.label[idx])


# define variable length data
# Toy dataset with variable-length samples, used to demonstrate why the
# default collate_fn fails for batch_size > 1.
class variable_dataset(Dataset):
    """Ten samples; sample ``idx`` is the value ``idx`` repeated ``idx + 1`` times."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        sequence = [idx for _ in range(idx + 1)]
        return {"input": torch.tensor(sequence),
                "label": torch.tensor(idx)}


if __name__ == "__main__":
    input_data = torch.tensor([[i, i+1, i+2] for i in range(0, 10)])
    label_data = torch.tensor([i for i in range(0, 10)]) 

    dataset = default_dataset(input_data, label_data)

    """DEFAULT SETTING"""
    dataloader = torch.utils.data.DataLoader(dataset) # batch O
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4) # batch X
   
    
    """RANDOM SAMPLER""" 
    random_sampler = RandomSampler(dataset) 
    
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=random_sampler)
    dataloader2 =torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)


    """BATCH SAMPLER"""
    random_sampler = RandomSampler(dataset)
    batch_sampler = BatchSampler(random_sampler, batch_size = 3, drop_last=False) #include batch_size, shuffle, drop_last
    dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler) 

    # for data in dataloader:
    #     print(data["input"], data["label"])

    """COLLATE FUNCTION""" # -> use when dataset is variable length 
    # var_dataset = variable_dataset()
    
    # dataloader = torch.utils.data.DataLoader(var_dataset)
    # for data in dataloader:
    #     print(data['input'])

    # # ERROR
    # dataloader = torch.utils.data.DataLoader(var_dataset, batch_size=2)
    # for data in dataloader:
    #     print(data['input'].shape, data['label'])