SoftLabelEvaluator
CLASS dd_ranking.metrics.SoftLabelEvaluator(config: Optional[Config] = None, dataset: str = 'CIFAR10', real_data_path: str = './dataset/', ipc: int = 10, model_name: str = 'ConvNet-3', soft_label_mode: str = 'S', soft_label_criterion: str = 'kl', temperature: float = 1.0, data_aug_func: str = 'cutmix', aug_params: dict = {'cutmix_p': 1.0}, optimizer: str = 'sgd', lr_scheduler: str = 'step', weight_decay: float = 0.0005, momentum: float = 0.9, num_eval: int = 5, im_size: tuple = (32, 32), num_epochs: int = 300, use_zca: bool = False, real_batch_size: int = 256, syn_batch_size: int = 256, default_lr: float = 0.01, save_path: str = None, stu_use_torchvision: bool = False, tea_use_torchvision: bool = False, num_workers: int = 4, teacher_dir: str = './teacher_models', custom_train_trans: Optional[Callable] = None, custom_val_trans: Optional[Callable] = None, device: str = 'cuda')
A class for evaluating the performance of a dataset distillation method with soft labels. Users can modify the attributes as needed.
Parameters
- config(Optional[Config]): Config object for specifying all attributes. See config for more details.
- dataset(str): Name of the real dataset.
- real_data_path(str): Path to the real dataset.
- ipc(int): Images per class.
- model_name(str): Name of the surrogate model. See models for more details.
- soft_label_mode(str): Number of soft labels per image: 'S' for a single soft label, 'M' for multiple soft labels.
- soft_label_criterion(str): Loss function for training with soft labels. Currently supports 'kl' for KL divergence and 'sce' for soft cross-entropy.
- temperature(float): Temperature for knowledge distillation.
- data_aug_func(str): Data augmentation function used during training. Currently supports 'dsa', 'cutmix', and 'mixup'. See augmentations for more details.
- aug_params(dict): Parameters for the data augmentation function.
- use_aug_for_hard(bool): Whether to use the data augmentation specified in data_aug_func for hard label evaluation.
- optimizer(str): Name of the optimizer. Currently supports torch-based optimizers: 'sgd', 'adam', and 'adamw'.
- lr_scheduler(str): Name of the learning rate scheduler. Currently supports torch-based schedulers: 'step', 'cosine', 'lambda_step', and 'lambda_cos'.
- weight_decay(float): Weight decay for the optimizer.
- momentum(float): Momentum for the optimizer.
- use_zca(bool): Whether to use ZCA whitening.
- num_eval(int): Number of evaluation rounds to perform.
- im_size(tuple): Size of the images.
- num_epochs(int): Number of epochs to train.
- real_batch_size(int): Batch size for the real dataset.
- syn_batch_size(int): Batch size for the synthetic dataset.
- stu_use_torchvision(bool): Whether to use torchvision to initialize the student model.
- tea_use_torchvision(bool): Whether to use torchvision to initialize the teacher model.
- teacher_dir(str): Path to the directory containing the teacher model.
- default_lr(float): Default learning rate for the optimizer, typically used for training on the real dataset.
- num_workers(int): Number of workers for data loading.
- save_path(Optional[str]): Path to save the results.
- custom_train_trans(Optional[Callable]): Custom transformation applied when loading the synthetic data. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- custom_val_trans(Optional[Callable]): Custom transformation applied when loading the test dataset. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- device(str): Device to use for evaluation, 'cuda' or 'cpu'.
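The two soft_label_criterion options can be illustrated with a framework-agnostic sketch in plain Python. It assumes that 'kl' follows the usual knowledge-distillation form (KL divergence between the temperature-softened teacher and student distributions, scaled by the temperature squared, with the soft labels treated as teacher logits) and that 'sce' takes soft labels already normalized to probabilities. The function names are hypothetical and do not describe how dd_ranking implements these losses internally.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature, computed via log_softmax."""
    return [math.exp(v) for v in log_softmax(logits, temperature)]


def kl_criterion(student_logits: List[float], teacher_logits: List[float],
                 temperature: float = 1.0) -> float:
    """Hypothetical 'kl': KL(teacher || student) on softened distributions, times T^2."""
    log_p = log_softmax(teacher_logits, temperature)
    log_q = log_softmax(student_logits, temperature)
    kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return kl * temperature ** 2


def sce_criterion(student_logits: List[float], soft_targets: List[float]) -> float:
    """Hypothetical 'sce': -sum_k y_k * log q_k, where y is a probability vector."""
    log_q = log_softmax(student_logits)
    return -sum(y * lq for y, lq in zip(soft_targets, log_q))
```

Because soft cross-entropy equals the entropy of the targets plus the KL divergence, the two criteria differ only by a term that does not depend on the student; at temperature 1.0 they therefore produce the same gradients with respect to the student logits.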
Methods
compute_metrics(image_tensor: Tensor = None, image_path: str = None, soft_labels: Tensor = None, syn_lr: float = None)
1. Compute the test accuracy of the surrogate model trained on the synthetic dataset with hard labels. The learning rate is tuned to obtain the best performance.
2. Compute the test accuracy of the surrogate model trained on the real dataset under the same setting as step 1.
3. Compute the test accuracy of the surrogate model trained on the synthetic dataset with soft labels.
4. Compute the test accuracy of the surrogate model trained on the randomly selected dataset under the same setting as step 3.
5. Compute the hard label recovery (HLR) and improvement over random (IOR) scores.
The final scores are the mean and standard deviation of the scores from num_eval rounds.
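Per round, the five steps reduce to two accuracy differences. The sketch below assumes HLR is the gap between hard-label accuracy on the real dataset (step 2) and on the synthetic dataset (step 1), and IOR is the gap between soft-label accuracy on the synthetic dataset (step 3) and on the randomly selected dataset (step 4). These definitions are inferred from the metric names and the step list; RoundAccuracies and round_scores are hypothetical names, not part of dd_ranking.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class RoundAccuracies:
    """Test accuracies (in %) from one evaluation round; field names are hypothetical."""
    syn_hard: float   # step 1: synthetic data, hard labels
    real_hard: float  # step 2: real data, same setting as step 1
    syn_soft: float   # step 3: synthetic data, soft labels
    rand_soft: float  # step 4: randomly selected data, same setting as step 3


def round_scores(acc: RoundAccuracies) -> Tuple[float, float]:
    """Return (HLR, IOR) for one round under the assumed definitions."""
    hlr = acc.real_hard - acc.syn_hard   # hard-label gap left to recover
    ior = acc.syn_soft - acc.rand_soft   # gain of distilled data over random data
    return hlr, ior
```

Under these definitions, a lower HLR means the synthetic images alone recover more of the real-data hard-label accuracy, and a higher IOR means the distilled data beats a random selection by a wider margin.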
Parameters
- image_tensor(Tensor): Image tensor. Must be specified when image_path is not provided. The shape must be (N x IPC, C, H, W), where N is the number of classes.
- image_path(str): Path to the image folder. Must be specified when image_tensor is not provided.
- soft_labels(Tensor): Soft label tensor. Must be specified when soft_label_mode is 'S'. Its first dimension must match that of image_tensor.
- syn_lr(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically.
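The input constraints above can be expressed as a small validation helper. This is a hypothetical sketch, not part of dd_ranking: it checks shapes rather than tensors (a torch.Size is a tuple subclass, so tensor.shape can be passed directly), and num_classes stands for N in the shape description.

```python
from typing import Optional, Sequence


def check_inputs(
    num_classes: int,
    ipc: int,
    image_shape: Optional[Sequence[int]] = None,
    image_path: Optional[str] = None,
    soft_label_shape: Optional[Sequence[int]] = None,
    soft_label_mode: str = "S",
) -> None:
    """Hypothetical helper: raise ValueError if the documented constraints are violated."""
    # At least one image source is required.
    if image_shape is None and image_path is None:
        raise ValueError("provide either image_tensor or image_path")
    if image_shape is not None:
        if len(image_shape) != 4:
            raise ValueError("image_tensor must have shape (N x IPC, C, H, W)")
        if image_shape[0] != num_classes * ipc:
            raise ValueError(
                f"expected {num_classes * ipc} images, got {image_shape[0]}"
            )
    # Soft labels are mandatory in 'S' mode and must align with the images.
    if soft_label_mode == "S":
        if soft_label_shape is None:
            raise ValueError("soft_labels is required when soft_label_mode is 'S'")
        if image_shape is not None and soft_label_shape[0] != image_shape[0]:
            raise ValueError("soft_labels and image_tensor must share the first dimension")
```

For example, check_inputs(10, 10, image_shape=(100, 3, 32, 32), soft_label_shape=(100, 10)) passes silently for a 10-class dataset at 10 images per class.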
Returns
A dictionary with the following keys:
- hard_label_recovery_mean: Mean of HLR scores from num_eval rounds.
- hard_label_recovery_std: Standard deviation of HLR scores from num_eval rounds.
- improvement_over_random_mean: Mean of improvement over random scores from num_eval rounds.
- improvement_over_random_std: Standard deviation of improvement over random scores from num_eval rounds.
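A dictionary with these four keys can be built from per-round (HLR, IOR) pairs using only the standard library. The helper name summarize is hypothetical, and the sketch assumes the population standard deviation (the default of numpy.std); dd_ranking may compute the spread differently.

```python
import statistics
from typing import Dict, List, Tuple


def summarize(per_round: List[Tuple[float, float]]) -> Dict[str, float]:
    """Hypothetical helper: aggregate (HLR, IOR) pairs from num_eval rounds."""
    hlr = [h for h, _ in per_round]
    ior = [i for _, i in per_round]
    return {
        "hard_label_recovery_mean": statistics.mean(hlr),
        # pstdev = population standard deviation (assumed, see lead-in)
        "hard_label_recovery_std": statistics.pstdev(hlr),
        "improvement_over_random_mean": statistics.mean(ior),
        "improvement_over_random_std": statistics.pstdev(ior),
    }
```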
Examples
With a config file:
>>> from dd_ranking.metrics import SoftLabelEvaluator
>>> config = Config('/path/to/config.yaml')
>>> evaluator = SoftLabelEvaluator(config=config)
>>> # load image and soft labels
>>> image_tensor, soft_labels = ...
>>> # compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
>>> # alternatively, provide the image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)
With keyword arguments:
>>> from dd_ranking.metrics import SoftLabelEvaluator
>>> evaluator = SoftLabelEvaluator(
...     dataset='TinyImageNet',
...     model_name='ResNet-18-BN',
...     soft_label_mode='M',
...     soft_label_criterion='kl',
...     temperature=10.0,
...     data_aug_func='mixup',
...     aug_params={
...         "mixup_p": 0.8,
...     },
...     optimizer='sgd',
...     lr_scheduler='step',
...     weight_decay=0.0005,
...     momentum=0.9,
...     stu_use_torchvision=True,
...     tea_use_torchvision=True,
...     num_eval=5,
...     device='cuda'
... )
>>> # load image and soft labels
>>> image_tensor, soft_labels = ...
>>> # compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
>>> # alternatively, provide the image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)