import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import pickle
import gdown
import os
from libcll.datasets.cl_base_dataset import CLBaseDataset
from libcll.datasets.utils import get_transition_matrix
class CLCIFAR10(torchvision.datasets.CIFAR10, CLBaseDataset):
    """
    Real-world complementary-label dataset. Call ``gen_complementary_target()`` if you want to access synthetic complementary labels.

    The training split is read from a pickled file (``clcifar10.pkl``) that is
    fetched with ``gdown`` when missing; the testing split is loaded through
    the ``torchvision.datasets.CIFAR10`` constructor.

    Parameters
    ----------
    root : str
        path to store dataset file.

    train : bool
        training set if True, else testing set.

    transform : callable, optional
        a function/transform that takes in a PIL image and returns a transformed version.

    target_transform : callable, optional
        a function/transform that takes in the target and transforms it.

    download : bool
        if true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

    num_cl : int
        the number of real-world complementary labels of each data chosen from [1, 3].
        Only the first ``num_cl`` label columns of the training file are kept,
        so ``num_cl=0`` keeps none of them.

    Attributes
    ----------
    data : Tensor
        the feature of sample set.

    targets : Tensor
        the complementary labels for corresponding sample (training split);
        for the testing split, the labels loaded by the CIFAR10 parent class.

    true_targets : Tensor
        the ground-truth labels for corresponding sample. Only assigned by
        this constructor for the training split.

    num_classes : int
        the number of classes.

    input_dim : int
        the feature space after data compressed into a 1D dimension.
    """

    def __init__(
        self,
        root="./data/cifar10",
        train=True,
        transform=None,
        target_transform=None,
        download=True,
        num_cl=1,
    ):
        if train:
            # The parent constructor is not called on this path: images and
            # labels come from the pickled complementary-label file instead.
            dataset_path = os.path.join(root, "clcifar10.pkl")
            if download and not os.path.exists(dataset_path):
                os.makedirs(root, exist_ok=True)
                gdown.download(
                    id="1uNLqmRUkHzZGiSsCtV2-fHoDbtKPnVt2", output=dataset_path
                )
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserializing. The file is fetched from a remote location, so
            # only use this loader with a source you trust.
            with open(dataset_path, "rb") as f:
                data = pickle.load(f)
            self.data = data["images"]
            self.true_targets = torch.Tensor(data["ord_labels"]).view(-1)
            # Keep only the first ``num_cl`` complementary labels per sample.
            self.targets = torch.Tensor(data["cl_labels"])[:, :num_cl]
            self.transform = transform
            self.target_transform = target_transform
        else:
            super(CLCIFAR10, self).__init__(
                root, train, transform, target_transform, download
            )
            self.targets = torch.Tensor(self.targets)
        self.num_classes = 10
        self.input_dim = 3 * 32 * 32

    @classmethod
    def build_dataset(
        cls,
        dataset_name=None,
        train=True,
        num_cl=0,
        transition_matrix=None,
        noise=None,
        seed=1126,
    ):
        """
        Build the training or testing split with its standard transforms.

        Parameters
        ----------
        dataset_name : str, optional
            when equal to ``"cifar10"`` (training split only), a transition
            matrix is obtained from ``get_transition_matrix`` and passed to
            ``gen_complementary_target``; any other value leaves the loaded
            labels untouched.

        train : bool
            training set (with random flip and padded random crop) if True,
            else testing set (tensor conversion and normalization only).

        num_cl : int
            forwarded to the constructor and to ``gen_complementary_target``.

        transition_matrix, noise, seed
            forwarded unchanged to ``get_transition_matrix``.

        Returns
        -------
        CLCIFAR10
            the constructed dataset.
        """
        # Both splits share the same per-channel normalization statistics.
        normalize = transforms.Normalize(
            [0.4914, 0.4822, 0.4465], [0.247, 0.2435, 0.2616]
        )

        if not train:
            test_transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    normalize,
                ]
            )
            return cls(
                train=False,
                transform=test_transform,
                num_cl=num_cl,
            )

        train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
                transforms.ToTensor(),
                normalize,
            ]
        )
        dataset = cls(
            train=True,
            transform=train_transform,
            num_cl=num_cl,
        )
        if dataset_name == "cifar10":
            Q = get_transition_matrix(
                transition_matrix, dataset.num_classes, noise, seed
            )
            dataset.gen_complementary_target(num_cl, Q)
        return dataset