# Source code for libcll.datasets.cl_mnist

import torchvision
import torchvision.transforms as transforms
from PIL import Image
from libcll.datasets.cl_base_dataset import CLBaseDataset
from libcll.datasets.utils import get_transition_matrix


class CLMNIST(torchvision.datasets.MNIST, CLBaseDataset):
    """MNIST dataset extended for complementary-label learning.

    Parameters
    ----------
    root : str
        Path to store the dataset files.
    train : bool
        Training set if True, else testing set.
    transform : callable, optional
        A function/transform that takes in a PIL image and returns a
        transformed version. Defaults to ``transforms.ToTensor()``.
    target_transform : callable, optional
        A function/transform that takes in the target and transforms it.
    download : bool
        If True, downloads the dataset from the internet and puts it in the
        root directory. If the dataset is already downloaded, it is not
        downloaded again.

    Attributes
    ----------
    data : Tensor
        The features of the sample set.
    targets : Tensor
        The complementary labels for the corresponding samples once
        ``gen_complementary_target`` has been called (``build_dataset`` does
        this for the training split only); otherwise the original MNIST
        labels.
    true_targets : Tensor
        The ground-truth labels for the corresponding samples.
    num_classes : int
        The number of classes (10).
    input_dim : int
        The size of one sample once flattened into a 1-D vector
        (1 * 28 * 28 = 784).
    """

    def __init__(
        self,
        root="./data/mnist",
        train=True,
        transform=transforms.ToTensor(),
        target_transform=None,
        download=True,
    ):
        # Forward the MNIST options by keyword so a change in the parent's
        # positional order cannot silently mis-assign them.
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.num_classes = 10
        self.input_dim = 1 * 28 * 28

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at ``index``.

        The raw image is converted to an 8-bit grayscale (mode ``"L"``) PIL
        image before ``transform`` is applied. ``target`` is taken from
        ``self.targets`` as-is (no ``int`` conversion) and then passed through
        ``target_transform`` if one is set.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    @classmethod
    def build_dataset(
        cls,
        dataset_name=None,
        train=True,
        num_cl=0,
        transition_matrix=None,
        noise=None,
        seed=1126,
    ):
        """Construct a CLMNIST split with the default root/transform/download.

        For the training split, a transition matrix ``Q`` is obtained from
        ``get_transition_matrix`` and used to generate complementary targets.
        The test split is returned without generating complementary targets,
        so its ``targets`` keep the ordinary MNIST labels.

        Parameters
        ----------
        dataset_name : str, optional
            Accepted but ignored by this implementation.
        train : bool
            Build the training split (with complementary labels) if True,
            else the test split.
        num_cl : int
            Forwarded to ``gen_complementary_target``; presumably the number
            of complementary labels per sample (confirm in CLBaseDataset).
        transition_matrix : optional
            Forwarded to ``get_transition_matrix``; see that helper for the
            accepted values.
        noise : optional
            Forwarded to ``get_transition_matrix``.
        seed : int
            Forwarded to ``get_transition_matrix``.

        Returns
        -------
        CLMNIST
            The constructed dataset.
        """
        if not train:
            # No complementary labels are generated for evaluation data.
            return cls(train=False)
        dataset = cls(train=True)
        Q = get_transition_matrix(transition_matrix, dataset.num_classes, noise, seed)
        dataset.gen_complementary_target(num_cl, Q)
        return dataset