LabelRobustScoreSoft

CLASS dd_ranking.metrics.LabelRobustScoreSoft(config: Optional[Config] = None, dataset: str = 'CIFAR10', real_data_path: str = './dataset/', ipc: int = 10, model_name: str = 'ConvNet-3', soft_label_mode: str='S', soft_label_criterion: str='kl', loss_fn_kwargs: dict=None, data_aug_func: str='cutmix', aug_params: dict={'cutmix_p': 1.0}, optimizer: str='sgd', lr_scheduler: str='step', weight_decay: float=0.0005, momentum: float=0.9, step_size: int=None, num_eval: int=5, im_size: tuple=(32, 32), num_epochs: int=300, use_zca: bool=False, use_aug_for_hard: bool=False, random_data_format: str='tensor', random_data_path: str=None, real_batch_size: int=256, syn_batch_size: int=256, save_path: str=None, eval_full_data: bool=False, stu_use_torchvision: bool=False, tea_use_torchvision: bool=False, num_workers: int=4, teacher_dir: str='./teacher_models', teacher_model_names: list=None, custom_train_trans: Optional[Callable]=None, custom_val_trans: Optional[Callable]=None, device: str="cuda", dist: bool=False)

A class for evaluating the performance of a dataset distillation method with soft labels. Users can modify the attributes as needed.

Parameters

  • config(Optional[Config]): Config object for specifying all attributes. See config for more details.
  • dataset(str): Name of the real dataset.
  • real_data_path(str): Path to the real dataset.
  • ipc(int): Images per class.
  • model_name(str): Name of the surrogate model. See models for more details.
  • soft_label_mode(str): Number of soft labels per image. S for single soft label, M for multiple soft labels.
  • soft_label_criterion(str): Loss function for using soft labels. Currently supports kl for KL divergence, sce for soft cross-entropy, and mse_gt for MSEGT loss introduced in EDC.
  • loss_fn_kwargs(dict): Keyword arguments for the loss function: temperature and scale_loss for the kl and sce criteria, and mse_weight and ce_weight for the mse_gt criterion.
  • data_aug_func(str): Data augmentation function used during training. Currently supports dsa, cutmix, mixup. See augmentations for more details.
  • aug_params(dict): Parameters for the data augmentation function.
  • use_aug_for_hard(bool): Whether to use the data augmentation specified in data_aug_func for hard label evaluation.
  • optimizer(str): Name of the optimizer. Currently supports torch-based optimizers - sgd, adam, and adamw.
  • lr_scheduler(str): Name of the learning rate scheduler. Currently supports torch-based schedulers - step, cosine, lambda_step, and cosineannealing.
  • weight_decay(float): Weight decay for the optimizer.
  • momentum(float): Momentum for the optimizer.
  • step_size(int): Step size for the learning rate scheduler.
  • use_zca(bool): Whether to use ZCA whitening.
  • num_eval(int): Number of evaluations to perform.
  • im_size(tuple): Size of the images.
  • num_epochs(int): Number of epochs to train.
  • real_batch_size(int): Batch size for the real dataset.
  • syn_batch_size(int): Batch size for the synthetic dataset.
  • stu_use_torchvision(bool): Whether to use torchvision to initialize the student model.
  • tea_use_torchvision(bool): Whether to use torchvision to initialize the teacher model.
  • teacher_dir(str): Path to the teacher model.
  • teacher_model_names(list): List of teacher model names.
  • random_data_format(str): Format of the random data, tensor or image.
  • random_data_path(str): Path to save the random data.
  • eval_full_data(bool): Whether to compute the test accuracy on the full dataset. This can be time-consuming on large datasets such as ImageNet1K, so a cache of full-dataset performance is provided.
  • num_workers(int): Number of workers for data loading.
  • save_path(Optional[str]): Path to save the results.
  • custom_train_trans(Optional[Callable]): Custom transformation function applied when loading synthetic data. Only torchvision transformations are supported. See torchvision-based transformations for more details.
  • custom_val_trans(Optional[Callable]): Custom transformation function applied when loading the test dataset. Only torchvision transformations are supported. See torchvision-based transformations for more details.
  • device(str): Device to use for evaluation, cuda or cpu.
  • dist(bool): Whether to use distributed training.
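
The kl and sce options of soft_label_criterion, together with the temperature and scale_loss entries of loss_fn_kwargs, can be illustrated with a minimal pure-Python sketch. It is not the library's implementation. The function names are hypothetical, the soft label is assumed to be a vector of teacher logits for kl and a probability vector for sce, and scale_loss is assumed to multiply the loss by the squared temperature, as is common in knowledge distillation.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]


def kl_soft_label_loss(student_logits, teacher_logits,
                       temperature=1.0, scale_loss=False):
    """KL(teacher || student) between temperature-softened distributions."""
    log_p = log_softmax(teacher_logits, temperature)  # soft label
    log_q = log_softmax(student_logits, temperature)  # model prediction
    loss = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    # Assumption: scale_loss rescales by T^2 (distillation convention).
    return loss * temperature ** 2 if scale_loss else loss


def soft_cross_entropy(student_logits, soft_label, temperature=1.0):
    """Cross-entropy against a probability vector instead of a class index."""
    log_q = log_softmax(student_logits, temperature)
    return -sum(p * lq for p, lq in zip(soft_label, log_q))
```

A higher temperature flattens both distributions, so the loss puts more weight on the relative ordering of the non-target classes.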

Methods

compute_metrics(image_tensor: Tensor = None, image_path: str = None, soft_labels: Tensor = None, syn_lr: float = None, lrs_lambda: float = 0.5)

This method computes the hard label recovery (HLR), improvement over random (IOR), and label robust score (LRS) for the given images and soft labels (if provided). In each evaluation round, a different random seed is set and the following steps are performed:
  1. Train the surrogate model on the synthetic dataset with hard labels and compute its test accuracy. The learning rate is tuned for the best performance.
  2. Train the surrogate model on the real dataset under the same setting as step 1 and compute its test accuracy.
  3. Train the surrogate model on the synthetic dataset with soft labels and compute its test accuracy.
  4. Train the surrogate model on a randomly selected dataset under the same setting as step 3 and compute its test accuracy.
  5. Compute the HLR and IOR scores.
  6. Compute the LRS.

The final scores are the average of the scores from num_eval rounds.
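
The rounds described above can be sketched as follows. This page does not state the score formulas, so the definitions in the code are assumptions: HLR is taken as the gap between real-data and synthetic-data accuracy under hard labels, IOR as the gain of the synthetic data over the random subset under soft labels, and LRS as a lrs_lambda-weighted combination of the two. The library may apply further normalization, and the function names are hypothetical.

```python
from statistics import mean, pstdev


def round_scores(acc_syn_hard, acc_real_hard, acc_syn_soft, acc_random_soft,
                 lrs_lambda=0.5):
    """Scores for one round from the four accuracies of steps 1 to 4."""
    # Assumed definitions (not given on this page):
    hlr = acc_real_hard - acc_syn_hard      # gap to real data, hard labels
    ior = acc_syn_soft - acc_random_soft    # gain over random data, soft labels
    lrs = lrs_lambda * ior - (1 - lrs_lambda) * hlr
    return hlr, ior, lrs


def aggregate(rounds, lrs_lambda=0.5):
    """Average per-round scores into the dictionary layout listed under Returns."""
    hlr, ior, lrs = zip(*(round_scores(*r, lrs_lambda=lrs_lambda) for r in rounds))
    # Assumption: population standard deviation across rounds.
    return {
        "hard_label_recovery_mean": mean(hlr),
        "hard_label_recovery_std": pstdev(hlr),
        "improvement_over_random_mean": mean(ior),
        "improvement_over_random_std": pstdev(ior),
        "label_robust_score_mean": mean(lrs),
        "label_robust_score_std": pstdev(lrs),
    }


# Two illustrative rounds of (syn_hard, real_hard, syn_soft, random_soft) accuracy.
example = aggregate([(40.0, 80.0, 60.0, 50.0), (42.0, 80.0, 62.0, 50.0)])
```

Under these assumptions a smaller HLR and a larger IOR both raise the LRS, and lrs_lambda controls how much each term counts.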

Parameters

  • image_tensor(Tensor): Image tensor. Must be specified when image_path is not provided. The shape must be (N x IPC, C, H, W), where N is the number of classes.
  • image_path(str): Path to the images. Must be specified when image_tensor is not provided.
  • soft_labels(Tensor): Soft label tensor. Must be specified when soft_label_mode is S. Its first dimension must match that of image_tensor.
  • syn_lr(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically.
  • lrs_lambda(float): Weighting parameter for the LRS.
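
The shape requirements above can be checked before calling compute_metrics. The helper below is a hypothetical sketch and not part of the library. It works on plain shape tuples, so with PyTorch tensors one would pass tuple(image_tensor.shape) and tuple(soft_labels.shape).

```python
def check_inputs(image_shape, soft_label_shape, num_classes, ipc,
                 soft_label_mode="S"):
    """Validate the documented shape constraints; raise ValueError on violation."""
    if len(image_shape) != 4:
        raise ValueError(f"expected (N x IPC, C, H, W), got {image_shape}")
    expected = num_classes * ipc
    if image_shape[0] != expected:
        raise ValueError(
            f"first dimension is {image_shape[0]}, expected {expected} "
            f"({num_classes} classes x {ipc} images per class)"
        )
    if soft_label_mode == "S":
        # Soft labels are only documented as mandatory in mode S.
        if soft_label_shape is None:
            raise ValueError("soft_labels must be given when soft_label_mode is 'S'")
        if soft_label_shape[0] != image_shape[0]:
            raise ValueError(
                f"soft_labels has {soft_label_shape[0]} rows, "
                f"image_tensor has {image_shape[0]}"
            )
    return True
```

For CIFAR10 with ipc=10, for example, the image tensor is expected to have 100 entries along its first dimension.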

Returns

A dictionary with the following keys:

  • hard_label_recovery_mean: Mean of HLR scores from num_eval rounds.
  • hard_label_recovery_std: Standard deviation of HLR scores from num_eval rounds.
  • improvement_over_random_mean: Mean of improvement over random scores from num_eval rounds.
  • improvement_over_random_std: Standard deviation of improvement over random scores from num_eval rounds.
  • label_robust_score_mean: Mean of LRS from num_eval rounds.
  • label_robust_score_std: Standard deviation of LRS from num_eval rounds.
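
A short sketch of how the returned dictionary might be turned into a readable report. The helper name and the numbers in the example are hypothetical; only the key names come from the list above.

```python
def format_results(results):
    """Render each score as 'NAME: mean +/- std' using the documented keys."""
    rows = [
        ("HLR", "hard_label_recovery"),
        ("IOR", "improvement_over_random"),
        ("LRS", "label_robust_score"),
    ]
    lines = []
    for name, prefix in rows:
        mean_value = results[prefix + "_mean"]
        std_value = results[prefix + "_std"]
        lines.append(f"{name}: {mean_value:.2f} +/- {std_value:.2f}")
    return "\n".join(lines)


# Made-up values, used only to show the layout.
report = format_results({
    "hard_label_recovery_mean": 39.0,
    "hard_label_recovery_std": 1.0,
    "improvement_over_random_mean": 11.0,
    "improvement_over_random_std": 1.0,
    "label_robust_score_mean": -14.0,
    "label_robust_score_std": 1.0,
})
```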

Examples

With a config file:

>>> from dd_ranking.metrics import LabelRobustScoreSoft
>>> # Config must also be imported; its module path is not shown on this page
>>> config = Config('/path/to/config.yaml')
>>> evaluator = LabelRobustScoreSoft(config=config)
>>> # load image and soft labels
>>> image_tensor, soft_labels = ...
>>> # compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
>>> # alternatively, provide image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)

With keyword arguments:

>>> evaluator = LabelRobustScoreSoft(
...     dataset='TinyImageNet',
...     real_data_path='./dataset/',
...     ipc=10,
...     model_name='ResNet-18-BN',
...     soft_label_mode='M',
...     soft_label_criterion='kl',
...     loss_fn_kwargs={
...         "temperature": 30.0,
...         "scale_loss": False,
...     },
...     data_aug_func='mixup',
...     aug_params={
...         "mixup_p": 0.8,
...     },
...     optimizer='sgd',
...     lr_scheduler='step',
...     num_epochs=300,
...     step_size=100,
...     weight_decay=0.0005,
...     momentum=0.9,
...     use_zca=False,
...     use_aug_for_hard=False,
...     stu_use_torchvision=True,
...     tea_use_torchvision=True,
...     num_workers=4,
...     save_path='./results',
...     eval_full_data=False,
...     random_data_format='image',
...     random_data_path='./random_data',
...     num_eval=5,
...     device='cuda'
... )
>>> # load image and soft labels
>>> image_tensor, soft_labels = ...
>>> # compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
>>> # alternatively, provide image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)