HardLabelEvaluator
CLASS dd_ranking.metrics.HardLabelEvaluator(config=None, dataset: str = 'CIFAR10', real_data_path: str = './dataset/', ipc: int = 10, model_name: str = 'ConvNet-3', data_aug_func: str = 'cutmix', aug_params: dict = {'cutmix_p': 1.0}, optimizer: str = 'sgd', lr_scheduler: str = 'step', weight_decay: float = 0.0005, momentum: float = 0.9, use_zca: bool = False, num_eval: int = 5, im_size: tuple = (32, 32), num_epochs: int = 300, real_batch_size: int = 256, syn_batch_size: int = 256, use_torchvision: bool = False, default_lr: float = 0.01, num_workers: int = 4, save_path: Optional[str] = None, custom_train_trans = None, custom_val_trans = None, device: str = "cuda")
A class for evaluating the performance of a dataset distillation method with hard labels. Users can modify the attributes as needed.
Parameters
- config(Optional[Config]): Config object for specifying all attributes. See config for more details.
- dataset(str): Name of the real dataset.
- real_data_path(str): Path to the real dataset.
- ipc(int): Images per class.
- model_name(str): Name of the surrogate model. See models for more details.
- data_aug_func(str): Data augmentation function used during training. Currently supports dsa, cutmix, and mixup. See augmentations for more details.
- aug_params(dict): Parameters for the data augmentation function.
- optimizer(str): Name of the optimizer. Currently supports torch-based optimizers: sgd, adam, and adamw.
- lr_scheduler(str): Name of the learning rate scheduler. Currently supports torch-based schedulers: step, cosine, lambda_step, and lambda_cos.
- weight_decay(float): Weight decay for the optimizer.
- momentum(float): Momentum for the optimizer.
- use_zca(bool): Whether to use ZCA whitening.
- num_eval(int): Number of evaluations to perform.
- im_size(tuple): Size of the images.
- num_epochs(int): Number of epochs to train.
- real_batch_size(int): Batch size for the real dataset.
- syn_batch_size(int): Batch size for the synthetic dataset.
- use_torchvision(bool): Whether to use torchvision to initialize the model.
- default_lr(float): Default learning rate for the optimizer, typically used for training on the real dataset.
- num_workers(int): Number of workers for data loading.
- save_path(Optional[str]): Path to save the results.
- custom_train_trans(Optional[Callable]): Custom transformation function applied when loading synthetic data. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- custom_val_trans(Optional[Callable]): Custom transformation function applied when loading the test dataset. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- device(str): Device to use for evaluation, cuda or cpu.
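Several of these parameters accept only a fixed set of strings, so it can help to check keyword arguments before constructing the evaluator. The sketch below is a hypothetical helper, not part of dd_ranking. The allowed values are copied from the parameter list above, and the library itself may accept more values or report errors differently.

```python
# Hypothetical pre-flight check; not part of dd_ranking.
# Allowed values are taken from the parameter descriptions above.
SUPPORTED = {
    "data_aug_func": {"dsa", "cutmix", "mixup"},
    "optimizer": {"sgd", "adam", "adamw"},
    "lr_scheduler": {"step", "cosine", "lambda_step", "lambda_cos"},
    "device": {"cuda", "cpu"},
}


def check_evaluator_kwargs(**kwargs):
    """Return the kwargs unchanged, or raise ValueError on an unsupported choice."""
    for name, allowed in SUPPORTED.items():
        if name in kwargs and kwargs[name] not in allowed:
            raise ValueError(
                f"{name}={kwargs[name]!r} is not one of {sorted(allowed)}"
            )
    return kwargs


kwargs = check_evaluator_kwargs(
    dataset="CIFAR10",
    ipc=10,
    data_aug_func="cutmix",
    aug_params={"cutmix_p": 1.0},
    optimizer="sgd",
    lr_scheduler="step",
    device="cpu",
)
# evaluator = HardLabelEvaluator(**kwargs)  # requires dd_ranking to be installed
```

Checking names up front surfaces a typo such as optimizer='sdg' before any training starts.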
Methods
compute_metrics(image_tensor: Tensor = None, image_path: str = None, hard_labels: Tensor = None, syn_lr: float = None)
This method computes the HLR, IOR, and DD-Ranking scores for the given images and hard labels (if provided). In each evaluation round, we set a different random seed and perform the following steps:
1. Compute the test accuracy of the surrogate model trained on the synthetic dataset with hard labels. The learning rate is tuned for the best performance if syn_lr is not provided.
2. Compute the test accuracy of the surrogate model trained on the real dataset under the same setting as step 1.
3. Compute the test accuracy of the surrogate model trained on a randomly selected dataset under the same setting as step 1.
4. Compute the HLR and IOR scores.
The final scores are the average of the scores from num_eval rounds.
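The aggregation over rounds can be sketched as follows. This is a minimal illustration, not the library's implementation. It assumes that HLR is the real-data accuracy minus the synthetic-data accuracy, that IOR is the synthetic-data accuracy minus the random-selection accuracy, and that the standard deviation is the population one; none of these is stated in this section. The function name aggregate_scores and the accuracy values are made up for the example.

```python
from statistics import mean, pstdev


def aggregate_scores(real_accs, syn_accs, random_accs):
    """Combine per-round test accuracies into mean/std HLR and IOR scores.

    Assumed definitions (not stated in this section):
      HLR = accuracy with real data - accuracy with synthetic data
      IOR = accuracy with synthetic data - accuracy with random selection
    """
    hlr = [r - s for r, s in zip(real_accs, syn_accs)]
    ior = [s - rnd for s, rnd in zip(syn_accs, random_accs)]
    return {
        "hard_label_recovery_mean": mean(hlr),
        "hard_label_recovery_std": pstdev(hlr),  # population std; an assumption
        "improvement_over_random_mean": mean(ior),
        "improvement_over_random_std": pstdev(ior),
    }


# Made-up accuracies for two rounds, purely illustrative.
scores = aggregate_scores(
    real_accs=[80.0, 82.0],
    syn_accs=[50.0, 54.0],
    random_accs=[30.0, 30.0],
)
```

Each list holds one accuracy per evaluation round, so its length would equal num_eval.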
Parameters
- image_tensor(Tensor): Image tensor. Must be specified when image_path is not provided. The required shape is (N x IPC, C, H, W), where N is the number of classes.
- image_path(str): Path to the image. Must be specified when image_tensor is not provided.
- hard_labels(Tensor): Hard label tensor. The first dimension must match that of image_tensor.
- syn_lr(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically.
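The shape requirements above can be checked before calling compute_metrics. The sketch below uses plain Python so that it runs without PyTorch. Both helpers are hypothetical, and the class-major label ordering (all images of class 0 first, then class 1, and so on) is an assumption, since this section does not say how images are ordered.

```python
def make_hard_labels(num_classes, ipc):
    """Class indices in class-major order: ipc copies of 0, then ipc copies of 1, ..."""
    return [c for c in range(num_classes) for _ in range(ipc)]


def check_shapes(image_shape, num_labels, num_classes, ipc):
    """Check an (N x IPC, C, H, W) image shape against the label count."""
    if len(image_shape) != 4:
        raise ValueError(f"expected 4 dimensions, got {len(image_shape)}")
    if image_shape[0] != num_classes * ipc:
        raise ValueError(
            f"first dimension is {image_shape[0]}, expected {num_classes * ipc}"
        )
    if image_shape[0] != num_labels:
        raise ValueError(
            f"{image_shape[0]} images but {num_labels} hard labels"
        )


labels = make_hard_labels(num_classes=10, ipc=10)
check_shapes((100, 3, 32, 32), len(labels), num_classes=10, ipc=10)
# With PyTorch installed, torch.tensor(labels) gives the label tensor and
# tuple(image_tensor.shape) gives the shape to check.
```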
Returns
A dictionary with the following keys:
- hard_label_recovery_mean: Mean of HLR scores from num_eval rounds.
- hard_label_recovery_std: Standard deviation of HLR scores from num_eval rounds.
- improvement_over_random_mean: Mean of improvement over random scores from num_eval rounds.
- improvement_over_random_std: Standard deviation of improvement over random scores from num_eval rounds.
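A returned dictionary with these keys can be turned into a one-line summary for logs or tables. The helper format_scores is hypothetical, and the numbers in the example dictionary are placeholders, not measured results.

```python
def format_scores(results):
    """Render the documented result keys as 'mean +/- std' strings."""
    hlr = (
        f"{results['hard_label_recovery_mean']:.2f}"
        f" +/- {results['hard_label_recovery_std']:.2f}"
    )
    ior = (
        f"{results['improvement_over_random_mean']:.2f}"
        f" +/- {results['improvement_over_random_std']:.2f}"
    )
    return f"HLR: {hlr} | IOR: {ior}"


# Placeholder values for illustration only, not measured results.
example = {
    "hard_label_recovery_mean": 29.0,
    "hard_label_recovery_std": 1.0,
    "improvement_over_random_mean": 22.0,
    "improvement_over_random_std": 2.0,
}
summary = format_scores(example)
```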
Examples:
with config file:
>>> config = Config('/path/to/config.yaml')
>>> evaluator = HardLabelEvaluator(config=config)
# load the image and hard labels
>>> image_tensor, hard_labels = ...
# compute the metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, hard_labels=hard_labels)
# alternatively, you can provide the image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', hard_labels=hard_labels)
with keyword arguments:
>>> evaluator = HardLabelEvaluator(
...     dataset='CIFAR10',
...     model_name='ConvNet-3',
...     data_aug_func='dsa',
...     aug_params={
...         "prob_flip": 0.5,
...         "ratio_rotate": 15.0,
...         "saturation": 2.0,
...         "brightness": 1.0,
...         "contrast": 0.5,
...         "ratio_scale": 1.2,
...         "ratio_crop_pad": 0.125,
...         "ratio_cutout": 0.5
...     },
...     optimizer='sgd',
...     lr_scheduler='step',
...     weight_decay=0.0005,
...     momentum=0.9,
...     use_zca=False,
...     num_eval=5,
...     device='cuda'
... )
# load the image and hard labels
>>> image_tensor, hard_labels = ...
# compute the metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, hard_labels=hard_labels)
# alternatively, you can provide the image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', hard_labels=hard_labels)