LabelRobustScoreSoft
CLASS dd_ranking.metrics.LabelRobustScoreSoft(config: Optional[Config] = None, dataset: str = 'CIFAR10', real_data_path: str = './dataset/', ipc: int = 10, model_name: str = 'ConvNet-3', soft_label_mode: str = 'S', soft_label_criterion: str = 'kl', loss_fn_kwargs: dict = None, data_aug_func: str = 'cutmix', aug_params: dict = {'cutmix_p': 1.0}, optimizer: str = 'sgd', lr_scheduler: str = 'step', weight_decay: float = 0.0005, momentum: float = 0.9, step_size: int = None, num_eval: int = 5, im_size: tuple = (32, 32), num_epochs: int = 300, use_zca: bool = False, use_aug_for_hard: bool = False, random_data_format: str = 'tensor', random_data_path: str = None, real_batch_size: int = 256, syn_batch_size: int = 256, save_path: str = None, eval_full_data: bool = False, stu_use_torchvision: bool = False, tea_use_torchvision: bool = False, num_workers: int = 4, teacher_dir: str = './teacher_models', teacher_model_names: list = None, custom_train_trans: Optional[Callable] = None, custom_val_trans: Optional[Callable] = None, device: str = 'cuda', dist: bool = False)
A class for evaluating the performance of a dataset distillation method with soft labels. Users can modify the attributes as needed.
Parameters
- config(Optional[Config]): Config object for specifying all attributes. See config for more details.
- dataset(str): Name of the real dataset.
- real_data_path(str): Path to the real dataset.
- ipc(int): Images per class.
- model_name(str): Name of the surrogate model. See models for more details.
- soft_label_mode(str): Number of soft labels per image. `S` for a single soft label, `M` for multiple soft labels.
- soft_label_criterion(str): Loss function for using soft labels. Currently supports `kl` for KL divergence, `sce` for soft cross-entropy, and `mse_gt` for MSEGT loss introduced in EDC.
- loss_fn_kwargs(dict): Keyword arguments for the loss function. `temperature` and `scale_loss` for KL and SCE loss, and `mse_weight` and `ce_weight` for MSE and CE loss.
- data_aug_func(str): Data augmentation function used during training. Currently supports `dsa`, `cutmix`, and `mixup`. See augmentations for more details.
- aug_params(dict): Parameters for the data augmentation function.
- use_aug_for_hard(bool): Whether to use the data augmentation specified in `data_aug_func` for hard label evaluation.
- optimizer(str): Name of the optimizer. Currently supports torch-based optimizers: `sgd`, `adam`, and `adamw`.
- lr_scheduler(str): Name of the learning rate scheduler. Currently supports torch-based schedulers: `step`, `cosine`, `lambda_step`, and `cosineannealing`.
- weight_decay(float): Weight decay for the optimizer.
- momentum(float): Momentum for the optimizer.
- step_size(int): Step size for the learning rate scheduler.
- use_zca(bool): Whether to use ZCA whitening.
- num_eval(int): Number of evaluations to perform.
- im_size(tuple): Size of the images.
- num_epochs(int): Number of epochs to train.
- real_batch_size(int): Batch size for the real dataset.
- syn_batch_size(int): Batch size for the synthetic dataset.
- stu_use_torchvision(bool): Whether to use torchvision to initialize the student model.
- tea_use_torchvision(bool): Whether to use torchvision to initialize the teacher model.
- teacher_dir(str): Path to the teacher model.
- teacher_model_names(list): List of teacher model names.
- random_data_format(str): Format of the random data, `tensor` or `image`.
- random_data_path(str): Path to save the random data.
- eval_full_data(bool): Whether to compute the test accuracy on the full dataset (might be time-consuming on large datasets such as ImageNet1K, so we have provided a full dataset performance cache).
- num_workers(int): Number of workers for data loading.
- save_path(Optional[str]): Path to save the results.
- custom_train_trans(Optional[Callable]): Custom transformation function applied when loading synthetic data. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- custom_val_trans(Optional[Callable]): Custom transformation function applied when loading the test dataset. Only torchvision transformations are supported. See torchvision-based transformations for more details.
- device(str): Device to use for evaluation, `cuda` or `cpu`.
- dist(bool): Whether to use distributed training.
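To make the `soft_label_criterion` and `loss_fn_kwargs` options more concrete, here is a minimal pure-Python sketch of a temperature-softened KL divergence and a soft cross-entropy for a single sample. It is an illustration only, not the library's implementation: the function names are hypothetical, and treating `scale_loss` as a multiplication by `temperature ** 2` is an assumption borrowed from common knowledge-distillation practice.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_soft_label_loss(student_logits, teacher_logits,
                       temperature=1.0, scale_loss=False):
    """KL(teacher || student) on temperature-softened distributions.

    Assumption: scale_loss multiplies the loss by temperature ** 2.
    The library may define this flag differently.
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    loss = sum(pi * (math.log(pi) - math.log(qi))
               for pi, qi in zip(p, q) if pi > 0)
    return loss * temperature ** 2 if scale_loss else loss


def soft_cross_entropy(student_logits, soft_label):
    """Cross-entropy between a soft label and the student's softmax output."""
    q = softmax(student_logits)
    return -sum(t * math.log(qi) for t, qi in zip(soft_label, q))


# Hypothetical logits for a 3-class problem.
student = [2.0, 0.5, -1.0]
teacher = [1.5, 1.0, -0.5]
print(kl_soft_label_loss(student, teacher, temperature=30.0))
print(soft_cross_entropy(student, softmax(teacher)))
```

A higher temperature flattens both distributions, so the unscaled KL value shrinks; that is the usual motivation for a scaling flag.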
Methods
compute_metrics(image_tensor: Tensor = None, image_path: str = None, soft_labels: Tensor = None, syn_lr: float = None, lrs_lambda: float = 0.5)
1. Compute the test accuracy of the surrogate model on the synthetic dataset under hard labels. We perform learning rate tuning for the best performance.
2. Compute the test accuracy of the surrogate model on the real dataset under the same setting as step 1.
3. Compute the test accuracy of the surrogate model on the synthetic dataset under soft labels.
4. Compute the test accuracy of the surrogate model on the randomly selected dataset under the same setting as step 3.
5. Compute the HLR and IOR scores.
6. Compute the LRS.
The final scores are the average of the scores from `num_eval` rounds.
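The way the four accuracies feed into the scores can be sketched as follows. This page does not state the formulas, so everything here is an assumption for illustration: HLR is taken as the gap between real-data and synthetic-data accuracy under hard labels, IOR as the gain of synthetic over random data under soft labels, and the two are combined with a `lrs_lambda`-weighted difference. The function name is hypothetical, and the library may apply a further normalization to obtain the LRS that is not shown here.

```python
def combine_scores(acc_real_hard, acc_syn_hard, acc_syn_soft, acc_random_soft,
                   lrs_lambda=0.5):
    """Combine four test accuracies into HLR, IOR, and a weighted score.

    Assumed definitions (not confirmed by this page):
      HLR = accuracy on real data - accuracy on synthetic data (hard labels)
      IOR = accuracy on synthetic data - accuracy on random data (soft labels)
      combined = lrs_lambda * IOR - (1 - lrs_lambda) * HLR
    """
    hlr = acc_real_hard - acc_syn_hard
    ior = acc_syn_soft - acc_random_soft
    combined = lrs_lambda * ior - (1.0 - lrs_lambda) * hlr
    return hlr, ior, combined


# Hypothetical accuracies from one evaluation round.
hlr, ior, combined = combine_scores(
    acc_real_hard=0.85,    # step 2: real data, hard labels
    acc_syn_hard=0.40,     # step 1: synthetic data, hard labels
    acc_syn_soft=0.60,     # step 3: synthetic data, soft labels
    acc_random_soft=0.45,  # step 4: random data, soft labels
)
print(hlr, ior, combined)
```

Under these assumed definitions, a smaller HLR and a larger IOR both push the combined score up, and `lrs_lambda` controls how much each term matters.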
Parameters
- image_tensor(Tensor): Image tensor. Must be specified when `image_path` is not provided. The shape must be `(N x IPC, C, H, W)`, where `N` is the number of classes.
- image_path(str): Path to the image. Must be specified when `image_tensor` is not provided.
- soft_labels(Tensor): Soft label tensor. Must be specified when `soft_label_mode` is `S`. The first dimension must be the same as that of `image_tensor`.
. - syn_lr(float): Learning rate for the synthetic dataset. If not specified, the learning rate will be tuned automatically.
- lrs_lambda(float): Weighting parameter for the LRS.
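The shape requirements above can be checked before calling `compute_metrics`. The helper below is a hypothetical sketch that works on plain shape tuples (for example `tuple(image_tensor.shape)`); it is not part of the library, and the error messages are made up for illustration.

```python
def check_inputs(image_shape, num_classes, ipc,
                 soft_label_shape=None, soft_label_mode="S"):
    """Validate input shapes against the documented requirements.

    image_shape must be (N x IPC, C, H, W), where N is the number of classes.
    In mode 'S', soft labels are required and their first dimension must
    match the first dimension of the images.
    """
    if len(image_shape) != 4:
        raise ValueError(f"expected a 4-D image shape, got {image_shape}")
    expected = num_classes * ipc
    if image_shape[0] != expected:
        raise ValueError(
            f"expected {expected} images (N x IPC), got {image_shape[0]}")
    if soft_label_mode == "S" and soft_label_shape is None:
        raise ValueError("soft_labels must be specified when soft_label_mode is 'S'")
    if soft_label_shape is not None and soft_label_shape[0] != image_shape[0]:
        raise ValueError("soft_labels and images differ in their first dimension")


# CIFAR10-like setting: 10 classes, 10 images per class, 32x32 RGB images.
check_inputs((100, 3, 32, 32), num_classes=10, ipc=10,
             soft_label_shape=(100, 10))
```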
Returns
A dictionary with the following keys:
- hard_label_recovery_mean: Mean of HLR scores from `num_eval` rounds.
- hard_label_recovery_std: Standard deviation of HLR scores from `num_eval` rounds.
- improvement_over_random_mean: Mean of improvement over random scores from `num_eval` rounds.
- improvement_over_random_std: Standard deviation of improvement over random scores from `num_eval` rounds.
- label_robust_score_mean: Mean of LRS from `num_eval` rounds.
- label_robust_score_std: Standard deviation of LRS from `num_eval` rounds.
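As an illustration of how per-round scores could be aggregated into a dictionary with these keys, here is a short sketch using only the standard library. The function name is hypothetical, and the use of the population standard deviation is an assumption; the library may use a different estimator.

```python
from statistics import fmean, pstdev


def summarize_rounds(hlr_scores, ior_scores, lrs_scores):
    """Aggregate per-round scores into mean/std entries.

    Assumption: population standard deviation (pstdev) is used.
    """
    named = {
        "hard_label_recovery": hlr_scores,
        "improvement_over_random": ior_scores,
        "label_robust_score": lrs_scores,
    }
    result = {}
    for prefix, values in named.items():
        result[f"{prefix}_mean"] = fmean(values)
        result[f"{prefix}_std"] = pstdev(values)
    return result


# Hypothetical scores from num_eval = 3 rounds.
summary = summarize_rounds(
    hlr_scores=[0.44, 0.46, 0.45],
    ior_scores=[0.14, 0.16, 0.15],
    lrs_scores=[0.30, 0.32, 0.31],
)
for key, value in summary.items():
    print(key, round(value, 4))
```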
Examples
With a config file:
>>> config = Config('/path/to/config.yaml')
>>> evaluator = LabelRobustScoreSoft(config=config)
# load image and soft labels
>>> image_tensor, soft_labels = ...
# compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
# alternatively, provide image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)
With keyword arguments:
>>> evaluator = LabelRobustScoreSoft(
... dataset='TinyImageNet',
... real_data_path='./dataset/',
... ipc=10,
... model_name='ResNet-18-BN',
... soft_label_mode='M',
... soft_label_criterion='kl',
... loss_fn_kwargs={
... "temperature": 30.0,
... "scale_loss": False,
... },
... data_aug_func='mixup',
... aug_params={
... "mixup_p": 0.8,
... },
... optimizer='sgd',
... lr_scheduler='step',
... num_epochs=300,
... step_size=100,
... weight_decay=0.0005,
... momentum=0.9,
... use_zca=False,
... use_aug_for_hard=False,
... stu_use_torchvision=True,
... tea_use_torchvision=True,
... num_workers=4,
... save_path='./results',
... eval_full_data=False,
... random_data_format='image',
... random_data_path='./random_data',
... num_eval=5,
... device='cuda'
... )
# load image and soft labels
>>> image_tensor, soft_labels = ...
# compute metrics
>>> evaluator.compute_metrics(image_tensor=image_tensor, soft_labels=soft_labels)
# alternatively, provide image path
>>> evaluator.compute_metrics(image_path='path/to/image/folder/', soft_labels=soft_labels)