From 046b1f668410b22071706f0f903da23f7cc7e1da Mon Sep 17 00:00:00 2001 From: dashu233 Date: Tue, 15 Jun 2021 11:16:15 +0800 Subject: [PATCH 1/3] alg update --- domainbed/algorithms.py | 617 ++++++++++++++++++++++++++-------------- 1 file changed, 398 insertions(+), 219 deletions(-) diff --git a/domainbed/algorithms.py b/domainbed/algorithms.py index 0d45facf6..0b63f29b4 100644 --- a/domainbed/algorithms.py +++ b/domainbed/algorithms.py @@ -4,23 +4,24 @@ import torch.nn as nn import torch.nn.functional as F import torch.autograd as autograd -from torch.autograd import Variable import copy import numpy as np -from collections import defaultdict from domainbed import networks -from domainbed.lib.misc import random_pairs_of_minibatches, ParamDict +from domainbed.lib.misc import random_pairs_of_minibatches +from domainbed.matrix_opt_for_train import opt_kde ALGORITHMS = [ 'ERM', - 'Fish', + 'RegularERM', + 'AblationRERM', 'IRM', 'GroupDRO', 'Mixup', 'MLDG', 'CORAL', + 'RegularCORAL' 'MMD', 'DANN', 'CDANN', @@ -30,9 +31,7 @@ 'VREx', 'RSC', 'SD', - 'ANDMask', - 'IGA', - 'SelfReg' + 'MyRSC' ] def get_algorithm_class(algorithm_name): @@ -105,65 +104,189 @@ def predict(self, x): def feature(self, x): return self.featurizer(x) -class Fish(Algorithm): +class CutERM(ERM): + def __init__(self, input_shape, num_classes, num_domains, hparams): + super(CutERM, self).__init__(input_shape, num_classes, num_domains, + hparams) + self.label_num = num_classes + self.feature_num = self.featurizer.n_outputs + self.env_num = num_domains + self.update_cal = 0 + + def update(self, minibatches, unlabeled=None): + self.update_cal += 1 # only difference with ERM + all_x = torch.cat([x for x, y in minibatches]) + all_y = torch.cat([y for x, y in minibatches]) + loss = F.cross_entropy(self.predict(all_x), all_y) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + return {'loss': loss.item()} + def update_classifer(self,cut_param:torch.Tensor): + cut_param.detach_() + 
cut_param.requires_grad = False + self.classifier.weight.data = torch.matmul(self.classifier.weight.data,cut_param) + + +class RegularERM(ERM): + def __init__(self, input_shape, num_classes, num_domains, hparams): + super(RegularERM, self).__init__(input_shape, num_classes, num_domains, + hparams) + + self.use_erm = False + if 'use_erm' in hparams: + self.use_erm = hparams['use_erm'] + self.label_num = num_classes + self.feature_num = self.featurizer.n_outputs + self.env_num = num_domains + self.gamma = 0 + self.mean_method = 'full_decay' + print('mean_method:',self.mean_method) + if 'mean_method' in hparams: + mth = hparams['mean_method'] + assert mth in ['decay_per_step','decay_per_data','full_decay'], 'make sure mean_method right' + self.mean_method = mth + + if 'RERM_gamma' in hparams: + self.gamma = hparams['RERM_gamma'] + self.lam = 0.0 + if 'RERM_lam' in hparams: + self.lam = hparams['RERM_lam'] + if self.lam < 1e-10: + self.use_erm = True + if not self.use_erm: + self.register_buffer('mean_table', + torch.zeros([self.label_num, self.env_num, self.feature_num])) + self.register_buffer('visit_table', + torch.zeros([self.label_num, self.env_num])) + # visit_table will be used in a different way in full_decay method + + + def update(self, minibatches, unlabeled=None): + if self.use_erm: + return super().update(minibatches) + + # calculate mean + # only the newest data will cause gradient + if self.mean_method == 'decay_per_data': + self.mean_table.detach_() + for env_id, (x, y) in enumerate(minibatches): + feats = self.feature(x) + + # sweep all label + for i in range(self.label_num): + fetch_ids = torch.where(y == i)[0] + if len(fetch_ids): + # if there is label i, average them + tmp = torch.mean(feats[fetch_ids, :], dim=0) + if self.visit_table[i, env_id] > 0.01: + # if it's not the fist time, do decay + self.mean_table[i, env_id, :] = (1 - self.gamma) * self.mean_table[i, env_id, :] \ + + self.gamma * tmp + else: + # otherwise, no decay, set visit to 1 + 
self.mean_table[i, env_id, :] = tmp + self.visit_table[i, env_id] = 1.0 + # else: + # self.mean_table[y,env_id,:] += self.gamma*feats + + loss1 = self.lam * torch.mean(torch.var(self.mean_table, dim=1)) + all_x = torch.cat([x for x, y in minibatches]) + all_y = torch.cat([y for x, y in minibatches]) + loss2 = F.cross_entropy(self.predict(all_x), all_y) + self.optimizer.zero_grad() + loss = loss1 + loss2 + loss.backward() + self.optimizer.step() + + return {'loss': loss.item(), + 'loss_acc': loss2.item(), + 'loss_mean': loss1.item(), + } + if self.mean_method == 'full_decay': + x0,_ = minibatches[0] + visit_table = x0.new_zeros([self.label_num,self.env_num]) + mean_table = x0.new_zeros([self.label_num,self.env_num,self.feature_num]) + for env_id, (x, y) in enumerate(minibatches): + feats = self.feature(x) + for i in range(self.label_num): + fetch_ids = torch.where(y == i)[0] + if len(fetch_ids): + # if there is label i, average them + visit_table[i,env_id] = 1.0 + mean_table[i,env_id,:] = torch.mean(feats[fetch_ids, :], dim=0) + mean_loss = x0.new_zeros([1]) + nzlabel = 0.0 + for i in range(self.label_num): + # TODO: vars sum or difference sum ? + fetch_envs = torch.where(visit_table[i,:]>0.01)[0] + if len(fetch_envs)>1: + mean_loss += torch.mean(torch.var(mean_table[i,fetch_envs,:],dim=1)) + nzlabel += 1.0 + mean_loss = self.lam * mean_loss / nzlabel + + all_x = torch.cat([x for x, y in minibatches]) + all_y = torch.cat([y for x, y in minibatches]) + loss2 = F.cross_entropy(self.predict(all_x), all_y) + loss = mean_loss + loss2 + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + return {'loss': loss.item(), + 'loss_acc': loss2.item(), + 'loss_mean': mean_loss.item(), + } + + def feature_predict(self,x): + return self.classifier(x) + + +class AblationRERM(Algorithm): """ - Implementation of Fish, as seen in Gradient Matching for Domain - Generalization, Shi et al. 2021. 
+ Empirical Risk Minimization (ERM) """ def __init__(self, input_shape, num_classes, num_domains, hparams): - super(Fish, self).__init__(input_shape, num_classes, num_domains, - hparams) - self.input_shape = input_shape - self.num_classes = num_classes - - self.network = networks.WholeFish(input_shape, num_classes, hparams) + super(AblationRERM, self).__init__(input_shape, num_classes, num_domains, + hparams) + self.featurizer = networks.Featurizer(input_shape, self.hparams) + self.classifier = networks.Classifier( + self.featurizer.n_outputs, + num_classes, + self.hparams['nonlinear_classifier']) + # TODO: find those para in dataset or args + self.label_num = 75 + self.feature_num = self.featurizer.n_outputs + env_num = 3 + self.layer_norm = nn.LayerNorm([self.feature_num]) + self.featurizer = nn.Sequential(self.featurizer,self.layer_norm) + self.network = nn.Sequential(self.featurizer,self.classifier) self.optimizer = torch.optim.Adam( self.network.parameters(), lr=self.hparams["lr"], weight_decay=self.hparams['weight_decay'] ) - self.optimizer_inner_state = None - - def create_clone(self, device): - self.network_inner = networks.WholeFish(self.input_shape, self.num_classes, self.hparams, - weights=self.network.state_dict()).to(device) - self.optimizer_inner = torch.optim.Adam( - self.network_inner.parameters(), - lr=self.hparams["lr"], - weight_decay=self.hparams['weight_decay'] - ) - if self.optimizer_inner_state is not None: - self.optimizer_inner.load_state_dict(self.optimizer_inner_state) - - def fish(self, meta_weights, inner_weights, lr_meta): - meta_weights = ParamDict(meta_weights) - inner_weights = ParamDict(inner_weights) - meta_weights += lr_meta * (inner_weights - meta_weights) - return meta_weights def update(self, minibatches, unlabeled=None): - self.create_clone(minibatches[0][0].device) - - for x, y in minibatches: - loss = F.cross_entropy(self.network_inner(x), y) - self.optimizer_inner.zero_grad() - loss.backward() - self.optimizer_inner.step() 
+ all_x = torch.cat([x for x, y in minibatches]) + all_y = torch.cat([y for x, y in minibatches]) + loss = F.cross_entropy(self.predict(all_x), all_y) - self.optimizer_inner_state = self.optimizer_inner.state_dict() - meta_weights = self.fish( - meta_weights=self.network.state_dict(), - inner_weights=self.network_inner.state_dict(), - lr_meta=self.hparams["meta_lr"] - ) - self.network.reset_weights(meta_weights) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() return {'loss': loss.item()} def predict(self, x): return self.network(x) + # 提取特征 + def feature(self, x): + return self.featurizer(x) + class ARM(ERM): """ Adaptive Risk Minimization (ARM) """ @@ -427,6 +550,39 @@ def update(self, minibatches, unlabeled=None): return {'loss': objective.item()} +class RegularMixup(Mixup): + class Mixup(ERM): + """ + Mixup of minibatches from different domains + https://arxiv.org/pdf/2001.00677.pdf + https://arxiv.org/pdf/1912.01805.pdf + """ + + def __init__(self, input_shape, num_classes, num_domains, hparams): + super(Mixup, self).__init__(input_shape, num_classes, num_domains, + hparams) + + def update(self, minibatches, unlabeled=None): + objective = 0 + + for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches): + lam = np.random.beta(self.hparams["mixup_alpha"], + self.hparams["mixup_alpha"]) + + x = lam * xi + (1 - lam) * xj + predictions = self.predict(x) + + objective += lam * F.cross_entropy(predictions, yi) + objective += (1 - lam) * F.cross_entropy(predictions, yj) + + objective /= len(minibatches) + + self.optimizer.zero_grad() + objective.backward() + self.optimizer.step() + + return {'loss': objective.item()} + class GroupDRO(ERM): """ @@ -619,6 +775,7 @@ def mmd(self, x, y): mean_y = y.mean(0, keepdim=True) cent_x = x - mean_x cent_y = y - mean_y + cent_y = y - mean_y cova_x = (cent_x.t() @ cent_x) / (len(x) - 1) cova_y = (cent_y.t() @ cent_y) / (len(y) - 1) @@ -674,6 +831,75 @@ def __init__(self, input_shape, num_classes, 
num_domains, hparams): super(CORAL, self).__init__(input_shape, num_classes, num_domains, hparams, gaussian=False) +class RegularCORAL(CORAL): + def __init__(self, input_shape, num_classes, num_domains, hparams): + super(RegularCORAL, self).__init__(input_shape, num_classes, + num_domains, hparams,) + self.use_erm = False + if 'use_erm' in hparams: + self.use_erm = hparams['use_erm'] + self.label_num = num_classes + self.feature_num = self.featurizer.n_outputs + self.env_num = num_domains + self.lam = 0.0 + if 'regular_lam' in hparams: + self.lam = hparams['regular_lam'] + if self.lam < 1e-10: + self.use_erm = True + if not self.use_erm: + self.register_buffer('mean_table', + torch.zeros([self.label_num, self.env_num, self.feature_num])) + self.register_buffer('visit_table', + torch.zeros([self.label_num, self.env_num])) + # visit_table will be used in a different way in full_decay method + + def update(self, minibatches, unlabeled=None): + objective = 0 + penalty = 0 + nmb = len(minibatches) + + features = [self.featurizer(xi) for xi, _ in minibatches] + classifs = [self.classifier(fi) for fi in features] + targets = [yi for _, yi in minibatches] + + for i in range(nmb): + objective += F.cross_entropy(classifs[i], targets[i]) + for j in range(i + 1, nmb): + penalty += self.mmd(features[i], features[j]) + + objective /= nmb + if nmb > 1: + penalty /= (nmb * (nmb - 1) / 2) + + x0, _ = minibatches[0] + visit_table = x0.new_zeros([self.label_num, self.env_num]) + mean_table = x0.new_zeros([self.label_num, self.env_num, self.feature_num]) + for env_id, feats in enumerate(features): + for i in range(self.label_num): + fetch_ids = torch.where(targets[env_id] == i)[0] + if len(fetch_ids): + visit_table[i, env_id] = 1.0 + mean_table[i, env_id, :] = torch.mean(feats[fetch_ids, :], dim=0) + mean_loss = x0.new_zeros([1]) + nzlabel = 0.0 + for i in range(self.label_num): + # TODO: vars sum or difference sum ? 
+ fetch_envs = torch.where(visit_table[i, :] > 0.01)[0] + if len(fetch_envs) > 1: + mean_loss += torch.mean(torch.var(mean_table[i, fetch_envs, :], dim=1)) + nzlabel += 1.0 + mean_loss = self.lam * mean_loss / nzlabel + + self.optimizer.zero_grad() + (objective + (self.hparams['mmd_gamma']*penalty) + mean_loss).backward() + self.optimizer.step() + + if torch.is_tensor(penalty): + penalty = penalty.item() + + return {'loss': objective.item(), 'penalty': penalty,'loss_mean':mean_loss.item()} + + class MTL(Algorithm): """ @@ -856,6 +1082,7 @@ class RSC(ERM): def __init__(self, input_shape, num_classes, num_domains, hparams): super(RSC, self).__init__(input_shape, num_classes, num_domains, hparams) + self.network self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100 self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100 self.num_classes = num_classes @@ -934,186 +1161,138 @@ def update(self, minibatches, unlabeled=None): return {'loss': loss.item(), 'penalty': penalty.item()} -class ANDMask(ERM): - """ - Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329] - AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary] - """ - - def __init__(self, input_shape, num_classes, num_domains, hparams): - super(ANDMask, self).__init__(input_shape, num_classes, num_domains, hparams) - - self.tau = hparams["tau"] - - def update(self, minibatches, unlabeled=None): - - total_loss = 0 - param_gradients = [[] for _ in self.network.parameters()] - all_x = torch.cat([x for x,y in minibatches]) - all_logits = self.network(all_x) - all_logits_idx = 0 - for i, (x, y) in enumerate(minibatches): - logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]] - all_logits_idx += x.shape[0] - - env_loss = F.cross_entropy(logits, y) - total_loss += env_loss - - env_grads = autograd.grad(env_loss, self.network.parameters(), retain_graph=True) - for grads, env_grad in zip(param_gradients, env_grads): - grads.append(env_grad) - - 
mean_loss = total_loss / len(minibatches) - - self.optimizer.zero_grad() - self.mask_grads(self.tau, param_gradients, self.network.parameters()) - self.optimizer.step() - - return {'loss': mean_loss.item()} - - def mask_grads(self, tau, gradients, params): - - for param, grads in zip(params, gradients): - grads = torch.stack(grads, dim=0) - grad_signs = torch.sign(grads) - mask = torch.mean(grad_signs, dim=0).abs() >= self.tau - mask = mask.to(torch.float32) - avg_grad = torch.mean(grads, dim=0) - - mask_t = (mask.sum() / mask.numel()) - param.grad = mask * avg_grad - param.grad *= (1. / (1e-10 + mask_t)) - - return 0 - -class IGA(ERM): - """ - Inter-environmental Gradient Alignment - From https://arxiv.org/abs/2008.01883v2 - """ - - def __init__(self, in_features, num_classes, num_domains, hparams): - super(IGA, self).__init__(in_features, num_classes, num_domains, hparams) - - def update(self, minibatches, unlabeled=False): - - all_x = torch.cat([x for x,y in minibatches]) - all_logits = self.network(all_x) - - total_loss = 0 - all_logits_idx = 0 - grads = [] - for i, (x, y) in enumerate(minibatches): - logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]] - all_logits_idx += x.shape[0] - - env_loss = F.cross_entropy(logits, y) - total_loss += env_loss - - grads.append( autograd.grad(env_loss, self.network.parameters(), retain_graph=True) ) - - mean_loss = total_loss / len(minibatches) - mean_grad = autograd.grad(mean_loss, self.network.parameters(), retain_graph=True) - - # compute trace penalty - penalty_value = 0 - for grad in grads: - for g, mean_g in zip(grad, mean_grad): - penalty_value += (g - mean_g).pow(2).sum() - - self.optimizer.zero_grad() - (mean_loss + self.hparams['penalty'] * penalty_value).backward() - self.optimizer.step() - +from torch.autograd import Variable +import random,math +class Identity(nn.Module): + """An identity layer""" + def __init__(self): + super(Identity, self).__init__() - return {'loss': mean_loss.item(), 'penalty': 
penalty_value.item()} + def forward(self, x): + return x - - -class SelfReg(ERM): +class MyRSC(ERM): def __init__(self, input_shape, num_classes, num_domains, hparams): - super(SelfReg, self).__init__(input_shape, num_classes, num_domains, + super(MyRSC, self).__init__(input_shape, num_classes, num_domains, hparams) - self.num_classes = num_classes - self.MSEloss = nn.MSELoss() - input_feat_size = self.featurizer.n_outputs - hidden_size = input_feat_size if input_feat_size==2048 else input_feat_size*2 - - self.cdpl = nn.Sequential( - nn.Linear(input_feat_size, hidden_size), - nn.BatchNorm1d(hidden_size), - nn.ReLU(inplace=True), - nn.Linear(hidden_size, hidden_size), - nn.BatchNorm1d(hidden_size), - nn.ReLU(inplace=True), - nn.Linear(hidden_size, input_feat_size), - nn.BatchNorm1d(input_feat_size) - ) - + self.avgpool = copy.deepcopy(self.featurizer.network.avgpool) + self.feature_num = self.featurizer.n_outputs + del self.featurizer.network.avgpool + self.featurizer.network.avgpool = Identity() + self.step = 0 + self.pecent = 0 + def update(self, minibatches, unlabeled=None): - + self.step += 1 all_x = torch.cat([x for x, y in minibatches]) + # labels all_y = torch.cat([y for _, y in minibatches]) - lam = np.random.beta(0.5, 0.5) - - batch_size = all_y.size()[0] - - # cluster and order features into same-class group - with torch.no_grad(): - sorted_y, indices = torch.sort(all_y) - sorted_x = torch.zeros_like(all_x) - for idx, order in enumerate(indices): - sorted_x[idx] = all_x[order] - intervals = [] - ex = 0 - for idx, val in enumerate(sorted_y): - if ex==val: - continue - intervals.append(idx) - ex = val - intervals.append(batch_size) - - all_x = sorted_x - all_y = sorted_y - - feat = self.featurizer(all_x) - proj = self.cdpl(feat) - - output = self.classifier(feat) - - # shuffle - output_2 = torch.zeros_like(output) - feat_2 = torch.zeros_like(proj) - output_3 = torch.zeros_like(output) - feat_3 = torch.zeros_like(proj) - ex = 0 - for end in intervals: - 
shuffle_indices = torch.randperm(end-ex)+ex - shuffle_indices2 = torch.randperm(end-ex)+ex - for idx in range(end-ex): - output_2[idx+ex] = output[shuffle_indices[idx]] - feat_2[idx+ex] = proj[shuffle_indices[idx]] - output_3[idx+ex] = output[shuffle_indices2[idx]] - feat_3[idx+ex] = proj[shuffle_indices2[idx]] - ex = end - - # mixup - output_3 = lam*output_2 + (1-lam)*output_3 - feat_3 = lam*feat_2 + (1-lam)*feat_3 - - # regularization - L_ind_logit = self.MSEloss(output, output_2) - L_hdl_logit = self.MSEloss(output, output_3) - L_ind_feat = 0.3 * self.MSEloss(feat, feat_2) - L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3) - - cl_loss = F.cross_entropy(output, all_y) - C_scale = min(cl_loss.item(), 1.) - loss = cl_loss + C_scale*(lam*(L_ind_logit + L_ind_feat)+(1-lam)*(L_hdl_logit + L_hdl_feat)) - + x = self.featurizer.network(all_x) + x = x.view(x.size(0),self.feature_num,-1) + feature_map_size = math.floor(math.sqrt(x.size(2))) + x = x.view(x.size(0),self.feature_num,feature_map_size,feature_map_size) + + interval = 1000 + if self.step % interval == 0: + self.pecent = 3.0 / 10 + (self.step / interval) * 2.0 / 10 + + self.eval() + x_new = x.clone().detach() + x_new = Variable(x_new.data, requires_grad=True) + #print('x_new:',x_new.size()) + x_new_view = self.avgpool(x_new) + x_new_view = x_new_view.view(x_new_view.size(0), -1) + #print('size:',x_new_view.size()) + output = self.classifier(x_new_view) + class_num = output.shape[1] + index = all_y + num_rois = x_new.shape[0] + num_channel = x_new.shape[1] + H = x_new.shape[2] + HW = x_new.shape[2] * x_new.shape[3] + one_hot = torch.zeros((1), dtype=torch.float32).cuda() + one_hot = Variable(one_hot, requires_grad=False) + sp_i = torch.ones([2, num_rois]).long() + sp_i[0, :] = torch.arange(num_rois) + sp_i[1, :] = index + sp_v = torch.ones([num_rois]) + one_hot_sparse = torch.sparse.FloatTensor(sp_i, sp_v, torch.Size([num_rois, class_num])).to_dense().cuda() + one_hot_sparse = Variable(one_hot_sparse, 
requires_grad=False) + one_hot = torch.sum(output * one_hot_sparse) + self.zero_grad() + one_hot.backward() + grads_val = x_new.grad.clone().detach() + grad_channel_mean = torch.mean(grads_val.view(num_rois, num_channel, -1), dim=2) + channel_mean = grad_channel_mean + grad_channel_mean = grad_channel_mean.view(num_rois, num_channel, 1, 1) + spatial_mean = torch.sum(x_new * grad_channel_mean, 1) + spatial_mean = spatial_mean.view(num_rois, HW) + self.zero_grad() + + choose_one = random.randint(0, 9) + if choose_one <= 4: + # ---------------------------- spatial ----------------------- + spatial_drop_num = math.ceil(HW * 1 / 3.0) + th18_mask_value = torch.sort(spatial_mean, dim=1, descending=True)[0][:, spatial_drop_num] + th18_mask_value = th18_mask_value.view(num_rois, 1).expand(num_rois, 49) + mask_all_cuda = torch.where(spatial_mean > th18_mask_value, torch.zeros(spatial_mean.shape).cuda(), + torch.ones(spatial_mean.shape).cuda()) + mask_all = mask_all_cuda.reshape(num_rois, H, H).view(num_rois, 1, H, H) + else: + # -------------------------- channel ---------------------------- + vector_thresh_percent = math.ceil(num_channel * 1 / 3.2) + vector_thresh_value = torch.sort(channel_mean, dim=1, descending=True)[0][:, vector_thresh_percent] + vector_thresh_value = vector_thresh_value.view(num_rois, 1).expand(num_rois, num_channel) + vector = torch.where(channel_mean > vector_thresh_value, + torch.zeros(channel_mean.shape).cuda(), + torch.ones(channel_mean.shape).cuda()) + mask_all = vector.view(num_rois, num_channel, 1, 1) + + # ----------------------------------- batch ---------------------------------------- + cls_prob_before = F.softmax(output, dim=1) + x_new_view_after = x_new * mask_all + x_new_view_after = self.avgpool(x_new_view_after) + x_new_view_after = x_new_view_after.view(x_new_view_after.size(0), -1) + x_new_view_after = self.classifier(x_new_view_after) + cls_prob_after = F.softmax(x_new_view_after, dim=1) + + sp_i = torch.ones([2, num_rois]).long() + 
sp_i[0, :] = torch.arange(num_rois) + sp_i[1, :] = index + sp_v = torch.ones([num_rois]) + one_hot_sparse = torch.sparse.FloatTensor(sp_i, sp_v, torch.Size([num_rois, class_num])).to_dense().cuda() + before_vector = torch.sum(one_hot_sparse * cls_prob_before, dim=1) + after_vector = torch.sum(one_hot_sparse * cls_prob_after, dim=1) + change_vector = before_vector - after_vector - 0.0001 + change_vector = torch.where(change_vector > 0, change_vector, torch.zeros(change_vector.shape).cuda()) + th_fg_value = torch.sort(change_vector, dim=0, descending=True)[0][ + int(round(float(num_rois) * self.pecent))] + drop_index_fg = change_vector.gt(th_fg_value).long() + ignore_index_fg = 1 - drop_index_fg + not_01_ignore_index_fg = ignore_index_fg.nonzero()[:, 0] + mask_all[not_01_ignore_index_fg.long(), :] = 1 + + self.train() + mask_all = Variable(mask_all, requires_grad=True) + x = x * mask_all + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + #print(x.size()) + pred = self.classifier(x) + loss = F.cross_entropy(pred, all_y) self.optimizer.zero_grad() loss.backward() self.optimizer.step() - return {'loss': loss.item()} + + def predict(self, x): + x = self.featurizer.network(x) + x = x.view(x.size(0), self.feature_num, -1) + feature_map_size = math.floor(math.sqrt(x.size(2))) + x = x.view(x.size(0), self.feature_num, feature_map_size, feature_map_size) + x = self.avgpool(x) + #print(x.size()) + x = x.view(x.size(0), -1) + return self.classifier(x) \ No newline at end of file From ffdb6ed1dc91624fca6da882d8a11e145cbf73d4 Mon Sep 17 00:00:00 2001 From: dashu233 Date: Tue, 15 Jun 2021 11:18:28 +0800 Subject: [PATCH 2/3] command update --- domainbed/command_launchers.py | 35 ++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/domainbed/command_launchers.py b/domainbed/command_launchers.py index cfcdf3132..9af18a0a6 100644 --- a/domainbed/command_launchers.py +++ b/domainbed/command_launchers.py @@ -9,7 +9,6 @@ import subprocess 
import time import torch - def local_launcher(commands): """Launch commands serially on the local machine.""" for cmd in commands: @@ -33,7 +32,7 @@ def multi_gpu_launcher(commands): while len(commands) > 0: for gpu_idx in range(n_gpus): - proc = procs_by_gpu[gpu_idx] + proc = procs_by_gpu[gpu_idx] # stupid github if (proc is None) or (proc.poll() is not None): # Nothing is running on this GPU; launch a command. cmd = commands.pop(0) @@ -48,14 +47,42 @@ def multi_gpu_launcher(commands): if p is not None: p.wait() +def my_multi_gpu_launcher(commands,available_list = [0,1,2,3]): + """ + Launch commands on the local machine, using all GPUs in parallel. + """ + print('WARNING: using experimental multi_gpu_launcher.') + n_gpus = torch.cuda.device_count() + procs_by_gpu = [None]*n_gpus + + while len(commands) > 0: + for gpu_idx in available_list: + proc = procs_by_gpu[gpu_idx] + if (proc is None) or (proc.poll() is not None): + # Nothing is running on this GPU; launch a command. + cmd = commands.pop(0) + new_proc = subprocess.Popen( + f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True) + procs_by_gpu[gpu_idx] = new_proc + break + time.sleep(60) + + # Wait for the last few tasks to finish before returning + for p in procs_by_gpu: + if p is not None: + p.wait() + + + REGISTRY = { 'local': local_launcher, 'dummy': dummy_launcher, - 'multi_gpu': multi_gpu_launcher + 'multi_gpu': multi_gpu_launcher, + 'my_multi_gpu':my_multi_gpu_launcher } try: from domainbed import facebook facebook.register_command_launchers(REGISTRY) except ImportError: - pass + pass \ No newline at end of file From 50bfc8e6f42211565aefcbde6d1730e2be185d87 Mon Sep 17 00:00:00 2001 From: dashu233 Date: Tue, 15 Jun 2021 11:26:13 +0800 Subject: [PATCH 3/3] full pipeline --- domainbed/command_launchers.py | 2 +- domainbed/feature_checker.py | 68 +++++ domainbed/matrix_opt_for_train.py | 309 ++++++++++++++++++++ logs/collect_feature.py | 44 +++ logs/matrix_optimizer_new.py | 156 ++++++---- logs/rename.py | 122 
++++++++ main.py | 460 ++++++++++++++++++++++++++++++ my_launcher.py | 305 ++++++++++++++++++++ 8 files changed, 1404 insertions(+), 62 deletions(-) create mode 100644 domainbed/matrix_opt_for_train.py create mode 100644 logs/collect_feature.py create mode 100644 logs/rename.py create mode 100644 main.py create mode 100644 my_launcher.py diff --git a/domainbed/command_launchers.py b/domainbed/command_launchers.py index 9af18a0a6..9783497d9 100644 --- a/domainbed/command_launchers.py +++ b/domainbed/command_launchers.py @@ -32,7 +32,7 @@ def multi_gpu_launcher(commands): while len(commands) > 0: for gpu_idx in range(n_gpus): - proc = procs_by_gpu[gpu_idx] # stupid github + proc = procs_by_gpu[gpu_idx] if (proc is None) or (proc.poll() is not None): # Nothing is running on this GPU; launch a command. cmd = commands.pop(0) diff --git a/domainbed/feature_checker.py b/domainbed/feature_checker.py index 80e2df0e8..5281a30cf 100644 --- a/domainbed/feature_checker.py +++ b/domainbed/feature_checker.py @@ -31,6 +31,74 @@ def calculate_f_star(algorithm, loaders, device, test_envs, num_classes): mean[label] += fss_list[label][i] return [mean[label]/len(fss_list[label]) for label in range(num_classes)] +def feature_extractor_for_train(algorithm, loaders, device, num_classes,val='in'): + print("————————Calculate the feature distribution for train————————") + print("") + #eval_list = [] + + fssd_list_raw = [{} for _ in range(num_classes)] + + for eval_name, eval_loader in loaders: + if val in eval_name: + fssd_score, fssd_raw = get_feature( + algorithm, eval_loader, device, num_classes, None, False, True) + #eval_list.append(eval_name) + for label in range(num_classes): + fssd_list_raw[label][eval_name[:-3]]=fssd_raw[label] + #print(fssd_list_raw[0]['env0']) + print("————————Finish Calculating————————") + return fssd_list_raw + + +def feature_extractor_for_pipline(algorithm, loaders, device, num_classes, marker="",val='in'): + print("————————Calculate the feature 
distribution————————") + print("") + eval_list = [] + fssd_list_raw = [[]for i in range(num_classes)] + fssd_mean = [[]for i in range(num_classes)] + fssd_variance = [[]for i in range(num_classes)] + feature_mean = [[]for i in range(num_classes)] + feature_var = [[]for i in range(num_classes)] + return_feat = [{} for _ in range(num_classes)] + for eval_name, eval_loader in loaders: + if val in eval_name: + # Debug 不设置f_star + # fssd_score, fssd_raw = get_feature( + # algorithm, eval_loader, device, num_classes, f_star, False, True) + start = time.time() + fssd_score, fssd_raw = get_feature( + algorithm, eval_loader, device, num_classes, None, False, True) + for label in range(num_classes): + return_feat[label][eval_name[:4]]=fssd_raw[label] + eval_list.append(eval_name) + print("Extract feature in env " + eval_name + " use time " + str(round(time.time()-start, 3))) + + for label in range(num_classes): + fssd_list_raw[label].append(fssd_score[label]) + fssd_mean[label].append(torch.mean(fssd_list_raw[label][-1])) + fssd_variance[label].append( + torch.var(fssd_list_raw[label][-1])) + feature_mean[label].append(torch.mean(fssd_raw[label], dim=0)) + feature_var[label].append(torch.var(fssd_raw[label], dim=0)) + + save_raw_feature = True + if save_raw_feature: + + np.save(marker + "_"+eval_name+"_label"+str(label) + + ".npy", fssd_raw[label].cpu().numpy()) + + # 直观图片打印 + save_some_image = False + if save_some_image: + feature_num = list(range(20)) + for num in feature_num: + feature_set = fssd_raw[label][:, num].cpu().numpy() + plt.figure() + plt.hist(feature_set, bins='auto', density=True) + plt.savefig("feature_imgae/feature"+str(num)+"_label"+str(label)+"_" + + eval_list[-1]+".png") + plt.close() + return return_feat def feature_extractor(algorithm, loaders, device, num_classes, marker=""): print("————————Calculate the feature distribution————————") diff --git a/domainbed/matrix_opt_for_train.py b/domainbed/matrix_opt_for_train.py new file mode 100644 index 
"""KDE-based feature-distribution matching utilities used during training.

Packs per-(label, environment) feature batches into one dense tensor and
optimises a (feature x feature) mixing matrix ``params`` so that per-feature
marginal densities agree across the training environments.
"""
import torch
import torch.nn as nn
import numpy as np
import time
import argparse


def shape_to_matrix(feature_num, env_list, label_num, max_data, data_len, data, device='cuda'):
    """Pack ragged per-(label, env) feature tensors into one dense tensor.

    Args:
        feature_num: number of feature channels per sample.
        env_list: environment ids to keep; ``env_list[e]`` indexes ``data[label]``.
        label_num: number of class labels.
        max_data: padded length of the sample axis (must be >= every entry of
            ``data_len``); shorter groups are left zero-padded.
        data_len: integer array of shape (env, label) holding the true sample
            count of each (environment, label) cell.
        data: nested mapping ``data[label][env] -> (n, feature_num)`` tensor.
        device: device on which the packed tensor is allocated.

    Returns:
        Tensor of shape (env, label, max_data, feature_num).
    """
    env_num = len(env_list)
    matrix = torch.zeros([env_num, label_num, max_data, feature_num], device=device)
    for env in range(env_num):
        for label in range(label_num):
            matrix[env][label][0:data_len[env, label]] = data[label][env_list[env]]
    return matrix


class opt_kde(torch.nn.Module):
    """Pairwise KDE distance between per-environment feature marginals.

    ``params`` is a (feature, feature) mixing matrix that is optimised (see
    :meth:`backward`) so that the 1-D kernel density estimates of each mixed
    feature channel agree across training environments.
    """

    def __init__(self, env_list, train_env, num_classes, feature_num, data, percent=0.5,
                 sample_size=1000, device='cuda'):
        # BUGFIX: an nn.Module subclass must initialise the base class before
        # any attribute assignment.
        super().__init__()
        self.sample_size = sample_size
        self.device = device
        self.envs = env_list
        self.train_env = train_env
        self.envs_num = len(self.envs)
        self.mask = None
        self.percent = percent

        # Count samples per (env, label) cell, then pack them into a single
        # zero-padded tensor.
        data_len = np.zeros((len(env_list), num_classes), dtype=np.int32)
        for i in range(len(env_list)):
            for j in range(num_classes):
                data_len[i][j] = len(data[j][env_list[i]])
        matrix = shape_to_matrix(feature_num=feature_num, env_list=env_list,
                                 label_num=num_classes,
                                 max_data=int(max([max(w) for w in data_len])),
                                 data_len=data_len, data=data, device=device)

        # Sanity checks: the packed tensor must match the declared sizes.
        self.feature_num = matrix.shape[3]
        assert self.feature_num == feature_num, "Error when loading feature"
        self.label_num = matrix.shape[1]
        assert self.label_num == num_classes, "Error when dealing with labels"
        self.max_sample = matrix.shape[2]
        assert matrix.shape[0] == len(env_list), \
            "length of envs in data does match provided envs"

        self.matrix = matrix

        self.data_len = torch.tensor(data_len, dtype=torch.float32)
        # 1 where a real sample exists, 0 in the zero-padded tail.
        self.data_mask = torch.ones(
            (self.envs_num, self.label_num, self.max_sample), dtype=torch.int32).to(self.device)
        for env in range(self.envs_num):
            for label in range(self.label_num):
                self.data_mask[env, label, data_len[env, label]:] -= 1
        self.len_unsqueeze = self.data_len.unsqueeze(2).to(self.device)

        # Silverman-style rule-of-thumb bandwidth for the Gaussian KDE.
        self.bandwidth = 1.06 * \
            self.max_sample ** (-1. / (1 + 4)) * \
            torch.std(matrix, dim=2).mean().clone().detach()
        self.offset = torch.exp(-0.5 / (self.bandwidth ** 2)).to(self.device)

        self.batch_len = 1
        self.batch_size = (self.sample_size +
                           self.batch_len - 1) // self.batch_len

        # BUGFIX: create the mixing matrix as a leaf tensor directly on the
        # target device. The original `torch.eye(..., requires_grad=True)
        # .to(device)` yields a non-leaf copy on CUDA, and on CPU the later
        # in-place SGD step on a leaf-that-requires-grad raises RuntimeError.
        self.params = torch.eye(self.feature_num, device=device, requires_grad=True)

    def normalize(self):
        """Rescale every column of ``params`` to (clamped) unit L2 norm."""
        denom = torch.sqrt(
            torch.sum(self.params ** 2, dim=0, keepdim=True)).detach().clamp_min_(1e-3)
        # Detach and re-mark as a leaf so repeated normalisation steps do not
        # accumulate an ever-growing autograd graph across iterations.
        self.params = (self.params / denom).detach().requires_grad_(True)

    def forward(self, cal_info=False, verbose=False, set_mask=False):
        """Estimate pairwise KDE distances between environments per feature.

        Mixes features through ``params``, estimates each (env, label,
        feature) marginal density on a shared 1-D grid with a Gaussian KDE,
        and accumulates integrated L1 distances between every pair of
        environments.

        Args:
            cal_info: additionally accumulate label-vs-label distances per env.
            verbose: print sampling-range and timing information.
            set_mask: instead of returning distances, build and return a
                (feature, feature) projection matrix that zeroes the top
                ``percent`` fraction of features ranked by train distance.

        Returns:
            A dict of distance tensors, or the projection matrix when
            ``set_mask`` is true.
        """
        # Project features through params; detach so the KDE sweep itself is
        # not differentiated.
        matrix = torch.matmul(self.matrix, self.params).detach().unsqueeze(dim=-1)
        left, right = torch.min(matrix).cpu().item(), torch.max(matrix).cpu().item()

        if verbose:
            print("sample message: from %.4f to %.4f, size is %d" %
                  (left, right, self.sample_size))
        delta = (right - left) / self.sample_size
        x_grid = torch.linspace(left, right, self.sample_size).to(self.device)
        divisor = np.sqrt(2 * np.pi) * self.bandwidth
        store_dis = torch.zeros(
            (self.envs_num * self.envs_num, self.label_num, self.feature_num)).to(self.device)
        if cal_info:
            store_info = torch.zeros((
                self.label_num * self.label_num, self.envs_num, self.feature_num
            )).to(self.device)
        reduce_zeros = torch.tensor(
            self.max_sample, dtype=torch.float32).to(self.device)

        # Flattened indices of (env_i, env_j) pairs where both environments
        # are training environments.
        index = 0
        train_index = []
        for envi in range(self.envs_num):
            for envj in range(self.envs_num):
                if self.envs[envi] in self.train_env and self.envs[envj] in self.train_env:
                    train_index.append(index)
                index += 1

        timing = 1000 // self.batch_len
        for batch in range(self.batch_size):
            if batch % timing == 0:
                start = time.time()
            points = x_grid[batch * self.batch_len:
                            min((batch + 1) * self.batch_len, self.sample_size)].reshape((1, -1))
            # KDE values at the grid points; contributions of the zero-padded
            # samples are subtracted before dividing by the true counts.
            reducer = (torch.sum(torch.pow(self.offset, (matrix - points) ** 2), dim=2) -
                       ((reduce_zeros - self.len_unsqueeze) *
                        torch.pow(self.offset, points ** 2)).unsqueeze(dim=2)
                       ) / self.len_unsqueeze.unsqueeze(dim=3)

            dis_expand = reducer.expand(
                (self.envs_num, self.envs_num, self.label_num, self.feature_num,
                 reducer.shape[-1]))
            store_dis += torch.sum(torch.abs(dis_expand - dis_expand.permute(1, 0, 2, 3, 4)),
                                   dim=-1).reshape(
                (-1, self.label_num, self.feature_num)) / divisor
            if cal_info:
                info_expand = reducer.permute(1, 0, 2, 3).expand(
                    (self.label_num, self.label_num, self.envs_num, self.feature_num,
                     reducer.shape[-1]))
                store_info += torch.sum(torch.abs(info_expand - info_expand.permute(1, 0, 2, 3, 4)),
                                        dim=-1).reshape(
                    (-1, self.envs_num, self.feature_num)) / divisor

            if batch % timing == timing - 1 and verbose:
                print("epoch %d, avg time: %f" %
                      ((batch + 1) * self.batch_len,
                       (time.time() - start) / timing / self.batch_len))

        # Worst case (max over env pairs) of the integrated L1 distance.
        test_results = (store_dis * delta / 2).max(dim=0)[0]
        train_results = (store_dis[train_index] * delta / 2).max(dim=0)[0]
        if verbose:
            print("finish forward once.")

        if set_mask:
            feature_dis = train_results.max(dim=0)[0].view(-1)
            # Channels with the LARGEST train distance are selected for
            # zeroing — NOTE(review): the original comment said "smallest",
            # which contradicts largest=True; behaviour kept as coded.
            self.mask = torch.topk(feature_dis, int(self.percent * len(feature_dis)),
                                   largest=True)[1]
            print('mask len:', len(self.mask))
            if len(self.mask) == 0:
                return torch.eye(self.params.size(0), device=self.device)
            save_param = self.params.detach().clone()
            save_param[:, self.mask] = 0
            inverse_param = torch.inverse(self.params)
            inverse_param[self.mask, :] = 0
            res = torch.matmul(save_param, inverse_param)
            print('diff from identity:',
                  torch.norm(res - torch.eye(self.feature_num, device=self.device)))
            return res

        if cal_info:
            # For each label pair take the env where the separation is
            # smallest, then the worst label pair per feature.
            train_info = (store_info * delta /
                          2).min(dim=1)[0].max(dim=0)[0].reshape((1, -1))
            return {
                "train_results": train_results,
                "test_results": test_results,
                "train_info": train_info,
                "train_dis": torch.mean(train_results.max(dim=0)[0]),
                "test_dis": torch.mean(test_results.max(dim=0)[0])
            }
        return {
            "train_results": train_results,
            "test_results": test_results,
            "train_info": None,
            "train_dis": torch.mean(train_results.max(dim=0)[0]),
            "test_dis": torch.mean(test_results.max(dim=0)[0])
        }

    @torch.no_grad()
    def pca(self):
        """Re-initialise ``params`` from the spread of per-(env, label) means.

        Builds the second-moment matrix of all pairwise differences of the
        (env, label) mean features and sets ``params`` from its eigen
        decomposition.
        """
        mean_value = torch.mean(self.matrix @ self.params, dim=2) * (
            self.max_sample / self.len_unsqueeze)
        # mean_value has shape (env, label, feature); take all pairwise
        # differences of those means.
        x = mean_value.unsqueeze(1)
        y = mean_value.unsqueeze(0)
        feat = (x - y).view(-1, self.feature_num)
        feat1 = feat.unsqueeze(2)
        feat2 = feat.unsqueeze(1)
        mat = torch.mean(feat1 * feat2, dim=0)
        # BUGFIX: torch.eig is deprecated/removed and returned eigenvalues as
        # an (n, 2) real/imag pair, which made the original diag/sqrt wrong.
        # `mat` is a mean of outer products, hence symmetric PSD, so eigh is
        # the correct decomposition.
        evals, evecs = torch.linalg.eigh(mat)
        print('min eig of data:{}'.format(torch.min(evals).item()))
        # Scale each eigenvector (column) by the square root of its
        # eigenvalue — NOTE(review): this follows the apparent intent of the
        # original `lam * eig[1]`; confirm the desired scaling direction.
        self.params = (evecs * torch.sqrt(evals.clamp_min(0))).requires_grad_(True)

    def backward(self, backward_method='mean', lr=1):
        """One gradient step on ``params``.

        Args:
            backward_method: only ``'mean'`` is implemented (minimise the
                variance of per-environment mean features).
            lr: step size of the plain SGD update.

        Raises:
            NotImplementedError: for the unfinished ``'L1'`` method.
        """
        if backward_method == 'L1':
            # BUGFIX: the original printed and called exit(), followed by
            # unreachable code referencing undefined names; raise instead so
            # the caller can react.
            raise NotImplementedError(
                "L1 backward is not ready, please use mean method to backward")
        elif backward_method == 'mean':
            mean_value = torch.mean(self.matrix @ self.params, dim=2) * (
                self.max_sample / self.len_unsqueeze)
            train_env_index = [w for w in range(self.envs_num)
                               if self.envs[w] in self.train_env]
            variance = torch.var(mean_value[train_env_index], dim=0)
            grad = torch.autograd.grad(variance.mean(), self.params)
            # BUGFIX: an in-place update of a leaf tensor that requires grad
            # raises a RuntimeError; perform the SGD step outside autograd.
            with torch.no_grad():
                self.params -= lr * grad[0]
            self.normalize()

    def eig_val(self):
        """Return the sorted real parts of the eigenvalues of ``params``
        (used to check whether the mixing matrix degenerates)."""
        # BUGFIX: the original referenced an undefined name (`eignormalizes`,
        # a guaranteed NameError) and used the removed torch.eig API.
        eigs = torch.linalg.eigvals(self.params)
        return np.sort(eigs.real.detach().cpu().numpy())
class opt_mmd(torch.nn.Module):
    """MMD-based counterpart of ``opt_kde`` (experimental, NOT functional).

    The constructor deliberately refuses to build an instance: the method is
    unfinished. The kernel helpers (:meth:`my_cdist`,
    :meth:`gaussian_kernel`, :meth:`mmd`) are self-contained and usable.
    """

    def __init__(self, matrix, data_len, sample_size, env_list, device='cuda'):
        # BUGFIX: initialise the nn.Module base class first.
        super().__init__()
        print("This method is not prepared. Please use opt_kde instead.")
        # BUGFIX: the original called exit(), killing the whole interpreter
        # from library code; raising lets the caller decide how to react.
        raise NotImplementedError("opt_mmd is not ready; use opt_kde instead.")
        # --- unreachable scaffolding kept for a future implementation ---
        self.device = device

        self.feature_num = matrix.shape[3]
        self.label_num = matrix.shape[1]
        self.max_sample = matrix.shape[2]
        assert matrix.shape[0] == len(env_list), \
            "length of envs in data does match provided envs"
        # NOTE(review): despite the name, this holds the kernel gammas
        # (10**-3 .. 10**3), not a sample count — confirm before finishing.
        self.sample_size = [torch.tensor(
            10 ** (gamma)).to(self.device) for gamma in range(-3, 4)]

        self.matrix = matrix  # Label x Env x Data_num x feature_num
        self.data_len = torch.tensor(
            data_len, dtype=torch.float32).to(self.device)
        self.envs = env_list
        self.envs_num = len(self.envs)

        # 1 where a real sample exists, 0 in the zero-padded tail.
        self.data_mask = torch.ones(
            (self.envs_num, self.label_num, self.max_sample)).to(self.device)
        for env in range(self.envs_num):
            for label in range(self.label_num):
                self.data_mask[env, label, data_len[env, label]:] -= 1.0

        self.params = torch.eye(self.feature_num).to(device)

        self.global_MMD = True

    def my_cdist(self, x1, x2):
        """Squared Euclidean distances between rows of x1 and x2,
        clamped away from zero for numerical stability."""
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100,
                                           1000]):
        """Sum of Gaussian kernels exp(-g * ||x - y||^2) over the gammas."""
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K

    def mmd(self, x, y):
        """Unbiased-style MMD^2 estimate between sample sets x and y."""
        Kxx = self.gaussian_kernel(x, x).mean()
        Kyy = self.gaussian_kernel(y, y).mean()
        Kxy = self.gaussian_kernel(x, y).mean()
        return Kxx + Kyy - 2 * Kxy

    def forward(self):
        """Accumulate a per-environment kernel statistic.

        Returns:
            (env_num, 1) tensor of accumulated statistics (also printed, as
            in the original implementation).
        """
        # Global MMD = \mean_{i,j} \sum_{g} \mean{env, env'}
        matrix = self.matrix @ self.params
        Kenv = torch.zeros((self.envs_num, 1)).to(self.device)
        for env in range(self.envs_num):  # self-comparison of each environment
            x_norm = (matrix[env].pow(2).sum(
                dim=-1, keepdim=True)) @ self.data_mask[env].unsqueeze(-1).transpose(-2, -1)
            res = -2 * \
                matrix[env] @ matrix[env].transpose(-2, -1) + \
                x_norm.transpose(-2, -1) + x_norm
            res.clamp_min_(1e-30)
            # NOTE(review): `self.sample_size` is iterated as kernel gammas.
            for g in self.sample_size:
                Kenv[env] += torch.mean(torch.exp(res.mul(-g)).mean(dim=[-1, -2]).add_(-1) * torch.pow(
                    self.max_sample / self.data_len[env], 2) + 1)
        print(Kenv)
        # Improvement: return the statistic instead of only printing it.
        return Kenv
max_data, data_len, data, device='cuda'): +def shape_to_matrix(feature_num, env_list, label_num, max_data, data_len, data, device='cuda',input_tensor=False): env_num = len(env_list) - matrix = np.zeros((env_num, label_num, max_data, + if input_tensor: + matrix = torch.zeros((env_num,label_num,max_data,feature_num),device=device) + else: + matrix = np.zeros((env_num, label_num, max_data, feature_num), dtype=np.float32) for env in range(env_num): for label in range(label_num): matrix[env][label][0:data_len[env, label] - ] = data[label][env_list[env]] - return torch.from_numpy(matrix).to(device) + ] = data[label][env_list[env]] + if input_tensor: + return matrix + else: + return torch.from_numpy(matrix).to(device) + def torch_to_numpy(d): return { @@ -22,14 +29,18 @@ def torch_to_numpy(d): if d[key] is not None } + class opt_kde(torch.nn.Module): - def __init__(self, env_list, train_env, num_classes, feature_num, args, data): - self.sample_size = args.sample_size - self.device = args.device + def __init__(self, env_list, train_env, num_classes, feature_num, data,sample_size,device): + self.sample_size = sample_size + self.device = device self.envs = env_list self.train_env = train_env self.envs_num = len(self.envs) self.train_env_index = [i for i in range(self.envs_num) if env_list[i] in self.train_env] + input_tensor = False + if isinstance(data[0][env_list[0]],torch.Tensor): + input_tensor = True # 准备初始化数据 data_len = np.zeros( @@ -37,10 +48,10 @@ def __init__(self, env_list, train_env, num_classes, feature_num, args, data): for i in range(len(env_list)): for j in range(num_classes): data_len[i][j] = data[j][env_list[i]].shape[0] - matrix = shape_to_matrix(feature_num=feature_num, env_list = env_list, label_num=num_classes, - max_data=int( - max([max(w) for w in data_len])), data_len=data_len, data=data, - device=args.device) + matrix = shape_to_matrix(feature_num=feature_num, env_list=env_list, label_num=num_classes, + max_data=int( + max([max(w) for w in 
data_len])), data_len=data_len, data=data, + device=device,input_tensor=input_tensor) # 确认参数匹配 self.feature_num = matrix.shape[3] @@ -50,46 +61,47 @@ def __init__(self, env_list, train_env, num_classes, feature_num, args, data): self.max_sample = matrix.shape[2] assert matrix.shape[0] == len( env_list), "length of envs in data does match provided envs" - + std = torch.std(matrix, dim=2).mean().clone().detach() self.matrix = matrix / std self.data_len = torch.tensor(data_len, dtype=torch.float32) self.data_mask = torch.ones( - (self.envs_num, self.label_num, self.max_sample),dtype=torch.int32).to(self.device) + (self.envs_num, self.label_num, self.max_sample), dtype=torch.int32).to(self.device) for env in range(self.envs_num): for label in range(self.label_num): self.data_mask[env, label, data_len[env, label]:] -= 1 self.len_unsqueeze = self.data_len.unsqueeze(2).to(self.device) self.bandwidth = 1.06 * \ - self.max_sample ** (-1. / (1 + 4)) * \ - torch.std(self.matrix, dim=2).mean().clone().detach() - + self.max_sample ** (-1. 
/ (1 + 4)) * \ + torch.std(self.matrix, dim=2).mean().clone().detach() + self.offset = torch.exp(-0.5 / (self.bandwidth ** 2)).to(self.device) - #self.sample_size = int(sample_size * (torch.max(matrix) - torch.min(matrix)).cpu().item()) + # self.sample_size = int(sample_size * (torch.max(matrix) - torch.min(matrix)).cpu().item()) - self.batch_len = args.batch_len + self.batch_len = 10 self.batch_size = (self.sample_size + self.batch_len - 1) // self.batch_len self.params = torch.eye( - self.feature_num,requires_grad=True).to(args.device) - - def normalize(self): # do normalization in params - self.params = self.params / torch.sqrt(torch.sum(self.params**2,dim=0,keepdim=True)).detach().clamp_min_(1e-3) - - def forward(self, cal_info=False,verbose=False,whether_backward=False,lr=None): - #backward = True - #lr = 10 + self.feature_num, requires_grad=True).to(device) + + def normalize(self): # do normalization in params + self.params = self.params / torch.sqrt(torch.sum(self.params ** 2, dim=0, keepdim=True)).detach().clamp_min_( + 1e-3) + + def forward(self, cal_info=False, verbose=False, whether_backward=False, lr=None,use_mean_info=True): + # backward = True + # lr = 10 if whether_backward == True: cal_info = False verbose = False accum_grad = torch.zeros_like(self.params).to(self.device) - #optimizer = torch.optim.SGD([self.params], lr=lr) - + # optimizer = torch.optim.SGD([self.params], lr=lr) + # matmul matrix params, s.t. 
check the results in this linear combination - matrix = torch.matmul(self.matrix, self.params).unsqueeze(dim=-1).detach() + matrix = torch.matmul(self.matrix, self.params).unsqueeze(dim=-1).detach() left, right = torch.min(matrix).cpu( ).item(), torch.max(matrix).cpu().item() if verbose: @@ -106,7 +118,6 @@ def forward(self, cal_info=False,verbose=False,whether_backward=False,lr=None): )).to(self.device) reduce_zeros = torch.tensor( self.max_sample, dtype=torch.float32).to(self.device) - index = 0 train_index = [] @@ -132,28 +143,34 @@ def forward(self, cal_info=False,verbose=False,whether_backward=False,lr=None): if batch % timing == 0: start = time.time() if whether_backward: - matrix = torch.matmul(self.matrix, self.params).unsqueeze(dim=-1) + matrix = torch.matmul(self.matrix, self.params).unsqueeze(dim=-1) points = x_gird[batch * self.batch_len:min((batch + 1) * self.batch_len, self.sample_size)].reshape((1, -1)) + #print( + #'over:',torch.sum(torch.pow(self.offset, (matrix - points) ** 2), dim=2) - + #((reduce_zeros - self.len_unsqueeze) * + # torch.pow(self.offset, points ** 2)).unsqueeze(dim=2)) + reducer = (torch.sum(torch.pow(self.offset, (matrix - points) ** 2), dim=2) - ((reduce_zeros - self.len_unsqueeze) * torch.pow(self.offset, points ** 2)).unsqueeze(dim=2) ) / self.len_unsqueeze.unsqueeze(dim=3) + #print('dw:',self.len_unsqueeze.unsqueeze(dim=3)) + #print('reducer:',reducer) dis_expand = reducer.expand( (self.envs_num, self.envs_num, self.label_num, self.feature_num, reducer.shape[-1])) - + if whether_backward: adder = torch.sum(torch.abs(dis_expand - dis_expand.permute(1, 0, 2, 3, 4)), dim=-1).reshape( (-1, self.label_num, self.feature_num)) / divisor store_dis = (store_dis + adder).detach() loss = (adder).mean() * delta / 2 - accum_grad += torch.autograd.grad(loss,self.params)[0].detach() + accum_grad += torch.autograd.grad(loss, self.params)[0].detach() else: store_dis += torch.sum(torch.abs(dis_expand - dis_expand.permute(1, 0, 2, 3, 4)), 
dim=-1).reshape( (-1, self.label_num, self.feature_num)) / divisor - if cal_info: info_expand = reducer.permute(1, 0, 2, 3).expand( @@ -170,40 +187,44 @@ def forward(self, cal_info=False,verbose=False,whether_backward=False,lr=None): train_results = (store_dis[train_index] * delta / 2).max(dim=0)[0] if verbose: print("finish forward once.") - + if whether_backward: self.params -= lr * accum_grad self.normalize() if cal_info: # should consider min env s.t. this to feature is exhibit, and select the biggest label pair - #train_info = (store_info * delta / 2).max(dim=0)[0] + # train_info = (store_info * delta / 2).max(dim=0)[0] # return a (1, feature_num) dimension - train_info_raw = (store_info[info_index][:,self.train_env_index,:] * delta / - 2).min(dim=1)[0].mean(dim=0).reshape((-1)) + if use_mean_info: + train_info_raw = (store_info[info_index][:, self.train_env_index, :] * delta / + 2).min(dim=1)[0].mean(dim=0).reshape((-1)) + else: + train_info_raw = (store_info[info_index][:, self.train_env_index, :] * delta / + 2).min(dim=1)[0].max(dim=0).reshape((-1)) return { "train_results": train_results, "test_results": test_results, "train_info": train_info_raw, - "train_dis":torch.mean(train_results.max(dim=0)[0]), - "test_dis":torch.mean(test_results.max(dim=0)[0]), - "info_mean":train_info_raw.mean() + "train_dis": torch.mean(train_results.max(dim=0)[0]), + "test_dis": torch.mean(test_results.max(dim=0)[0]), + "info_mean": train_info_raw.mean() } return { "train_results": train_results, "test_results": test_results, "train_info": None, - "train_dis":torch.mean(train_results.max(dim=0)[0]), - "test_dis":torch.mean(test_results.max(dim=0)[0]), - "info_mean":None, + "train_dis": torch.mean(train_results.max(dim=0)[0]), + "test_dis": torch.mean(test_results.max(dim=0)[0]), + "info_mean": None, } - def backward(self, backward_method = 'mean', lr = 1): + def backward(self, backward_method='mean', lr=1): if backward_method == 'L1': # 这里考虑对env取完max之后对label做mean,这样子可以增加数据量 # 
argmax: label x feature → train_index上的index # 表示的是对这个params的分量,这个label,是哪两个环境参与了max dis的计算 - results = torch_to_numpy(self.forward(whether_backward = True,lr=lr)) + results = torch_to_numpy(self.forward(whether_backward=True, lr=lr)) print("Before training, Train dis is %.4f, test dis is %.4f" % (results['train_dis'], results['test_dis'])) @@ -212,14 +233,11 @@ def backward(self, backward_method = 'mean', lr = 1): np.array(train_index,dtype=np.longlong)).to(self.device), 0, argmax.view(-1)).reshape((-1, 1)) index = torch.cat([cluster_index // 4, cluster_index % 4], dim=1).reshape((-1,2,1)) # index is (label*feature)*2(represent 2 env taken by this pair) - update_matrix = self.matrix.permute(1, 3, 0, 2).reshape(( -1, self.envs_num,self.max_sample)).gather(dim=1, index=index.expand(index.shape[0],index.shape[1],self.max_sample)).reshape(( self.label_num,self.feature_num,2,self.max_sample )) - - argmax = store_dis[train_index].reshape( (-1, self.feature_num)).max(dim=0)[1].reshape((-1, 1)) # TODO: add appropriate index @@ -231,17 +249,33 @@ def backward(self, backward_method = 'mean', lr = 1): ''' elif backward_method == 'mean': - mean_value = torch.mean(self.matrix @ self.params, dim = 2) * ( - self.max_sample / self.len_unsqueeze) + mean_value = torch.mean(self.matrix @ self.params, dim=2) * ( + self.max_sample / self.len_unsqueeze) train_env_index = [w for w in range(self.envs_num) if self.envs[w] in self.train_env] - variance = torch.var(mean_value[train_env_index],dim=0) - grad = torch.autograd.grad(variance.mean(),self.params) + variance = torch.var(mean_value[train_env_index], dim=0) + grad = torch.autograd.grad(variance.mean(), self.params) self.params -= lr * grad[0] self.normalize() - + + @torch.no_grad() + def pca(self): + mean_value = torch.mean(self.matrix @ self.params, dim=2) * ( + self.max_sample / self.len_unsqueeze) + # mean_valuse is of shape (env,label,feature) + x = mean_value.unsqueeze(1) + y = mean_value.unsqueeze(0) + feat = 
(x-y).view(-1,self.feature_num) + feat1 = feat.unsqueeze(2) + feat2 = feat.unsqueeze(1) + mat = torch.mean(feat1*feat2,dim=0) + eig = torch.eig(mat,eigenvectors=True) + print('min eig of data:{}'.format(torch.min(eig[0]).item())) + lam = torch.diag(torch.sqrt(eig[0])) + self.params = lam*eig[1] + def eig_val(self): # return sorted eig value, to check whether degenerate - eigs = torch.eig(self.params) - return np.sort(eigs[0].detach().cpu().numpy()[:,0]) + eigs = torch.eig(self.params) + return np.sort(eigs[0].detach().cpu().numpy()[:, 0]) class opt_mmd(torch.nn.Module): @@ -256,7 +290,7 @@ def __init__(self, matrix, data_len, sample_size, env_list, device='cuda'): assert matrix.shape[0] == len( env_list), "length of envs in data does match provided envs" self.sample_size = [torch.tensor( - 10**(gamma)).to(self.device) for gamma in range(-3, 4)] + 10 ** (gamma)).to(self.device) for gamma in range(-3, 4)] self.matrix = matrix # Label x Env x Data_num x feature_num self.data_len = torch.tensor( @@ -304,14 +338,14 @@ def forward(self): # Global MMD = \mean_{i,j} \sum_{g} \mean{env, env'} matrix = self.matrix @ self.params Kenv = torch.zeros((self.envs_num, 1)).to(self.device) - for env in range(self.envs_num): # 先计算自己和自己的 + for env in range(self.envs_num): # 先计算自己和自己的 x_norm = (matrix[env].pow(2).sum( dim=-1, keepdim=True)) @ self.data_mask[env].unsqueeze(-1).transpose(-2, -1) res = -2 * \ - matrix[env] @ matrix[env].transpose(-2, -1) + \ - x_norm.transpose(-2, -1) + x_norm + matrix[env] @ matrix[env].transpose(-2, -1) + \ + x_norm.transpose(-2, -1) + x_norm res.clamp_min_(1e-30) for g in self.sample_size: # MMD中的kernel Kenv[env] += torch.mean(torch.exp(res.mul(-g)).mean(dim=[-1, -2]).add_(-1) * torch.pow( self.max_sample / self.data_len[env], 2) + 1) - print(Kenv) + print(Kenv) \ No newline at end of file diff --git a/logs/rename.py b/logs/rename.py new file mode 100644 index 000000000..36fc687c1 --- /dev/null +++ b/logs/rename.py @@ -0,0 +1,122 @@ +import re +import 
json +import pandas as pd +import os +import math +import copy +dir_list = os.listdir('./') +before = False +dataset = 'ColoredMNIST' +forbid = False +env_list = [2] +def floatequ(a,b): + return math.fabs(a-b) < 1e-8 +for env in env_list: + for before in [True]: + output_file = 'renamed/collection_{}_test{}_before_feanum.csv'.format(dataset, env) if before \ + else 'renamed/collection_{}_test{}_drag_feanum.csv'.format(dataset, env) + read_file = 'collection/data_collection_before_{}_env{}.csv'.format(dataset, env) if before else \ + 'collection/data_collection_{}_env{}.csv'.format(dataset, env) + dt = None + + read_file_frame = pd.read_csv(read_file) + names = read_file_frame['name'] + lr_list = [] + step_list = [] + pn_list = [] + sd_list = [] + method_list = [] + pn_name_list = ["groupdro_eta", "irm_lambda", "mixup_alpha", "vrex_lambda", 'mmd_gamma', 'no_pn'] + index_list = [] + + idx = -1 + for name in names: + idx += 1 + #print(name) + tp = re.search(r'\{[^\}]+\}', name) + dp = name[tp.span()[0]:tp.span()[1]] + dp = dp.replace('=', ':').replace('*', '.').replace(r'_"', r',"') + dt = json.loads(dp) + try: + assert 'lr' in dt + except Exception as e: + print(e) + + tmp_lr = dt['lr'] + break_flag = False + if forbid and (not (floatequ(tmp_lr, 1e-4) or floatequ(tmp_lr, 5e-5))): + continue + for pn in pn_name_list: + if pn == 'no_pn': + tmp_pn = 0 + if pn in dt: + if pn == "groupdro_eta": + if forbid and (not (floatequ(dt[pn],0.1) or floatequ(dt[pn],0.01))): + break_flag=True + break + tmp_pn = dt[pn] + break + if break_flag: + continue + tp = re.search(r'step_[0-9]*', name) + tmp_step = int(name[tp.span()[0] + 5:tp.span()[1]]) + if forbid and (tmp_step not in [2500, 5000]): + continue + + tp = re.search(r'_[0-9]*_', name) + tmp_sd = int(name[tp.span()[0] + 1:tp.span()[1] - 1]) + if forbid and (tmp_sd not in [0, 1, 2, 3, 4]): + continue + + tp = re.search(r'\A[A-Za-z]*_', name) + dp = name[tp.span()[0]: tp.span()[1] - 1] + method_list.append(dp) + 
lr_list.append(tmp_lr) + pn_list.append(tmp_pn) + step_list.append(tmp_step) + sd_list.append(tmp_sd) + + index_list.append(idx) + + #print('added') + print(len(index_list)) + read_file_frame2 = copy.deepcopy(read_file_frame.iloc[index_list]) + #print(read_file_frame2.shape) + + read_file_frame2['algorithm'] = method_list + read_file_frame2['step'] = step_list + read_file_frame2['lr'] = lr_list + read_file_frame2['penalty'] = pn_list + + read_file_frame2['seed'] = sd_list + + first_list = ['algorithm', 'step', 'lr', 'penalty', 'seed'] + last_list = ['name'] + + column_list = list(read_file_frame2) + for _ in first_list: + column_list.remove(_) + for _ in last_list: + column_list.remove(_) + + column_list = first_list + column_list + last_list + read_file_frame2 = read_file_frame2[column_list] + + # import math + thr_list = [] + for thr in range(9): + thr_list.append(round(thr * 0.05, 2)) + cls = {} + for thr in thr_list: + cls['test_dis_{}'.format(thr)] = '' + cls['train_info_{}'.format(thr)] = '' + cls['train_dis_{}'.format(thr)] = thr + read_file_frame2 = read_file_frame2.rename(columns=cls) + read_file_frame2.index = [_ for _ in range(len(index_list))] + read_file_frame2.to_csv(output_file) + + #print(cls) + #print(read_file_frame2.columns) + + + diff --git a/main.py b/main.py new file mode 100644 index 000000000..cfd82dacb --- /dev/null +++ b/main.py @@ -0,0 +1,460 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + +import argparse +import collections +import json +import os +import random +import sys +import time +import uuid +import copy + +import numpy as np +import PIL +import torch +import torchvision +import torch.utils.data + +from domainbed import datasets +from domainbed import hparams_registry +from domainbed import algorithms +from domainbed.lib import misc +from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader +from domainbed.feature_checker import feature_extractor_for_pipline,feature_extractor_for_train +import domainbed.matrix_opt_for_train as moft +from logs.matrix_optimizer_new import opt_kde + + +def torch_to_numpy(d): + return { + key: d[key].cpu().numpy() + for key in d.keys() + if d[key] is not None + } + +def to_str(lis): + s = "" + for w in lis: + s = s + str(w).ljust(10," ") + ", " + return s + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Domain generalization') + parser.add_argument('--data_dir', type=str) + parser.add_argument('--dataset', type=str, default="RotatedMNIST") + parser.add_argument('--algorithm', type=str, default="ERM") + parser.add_argument('--task', type=str, default="domain_generalization", + help='domain_generalization | domain_adaptation') + parser.add_argument('--hparams', type=str, + help='JSON-serialized hparams dict') + parser.add_argument('--hparams_seed', type=int, default=0, + help='Seed for random hparams (0 means "default hparams")') + parser.add_argument('--trial_seed', type=int, default=0, + help='Trial number (used for seeding split_dataset and ' + 'random_hparams).') + parser.add_argument('--seed', type=int, default=0, + help='Seed for everything else') + parser.add_argument('--steps', type=int, default=None, + help='Number of steps. Default is dataset-dependent.') + parser.add_argument('--checkpoint_freq', type=int, default=1000, + help='Checkpoint every N steps. 
Default is dataset-dependent.') + parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) + parser.add_argument('--output_dir', type=str, default="train_output/") + parser.add_argument('--holdout_fraction', type=float, default=0.2) + parser.add_argument('--uda_holdout_fraction', type=float, default=0) + parser.add_argument('--skip_model_save', action='store_true') + parser.add_argument('--save_feature_every_checkpoint', action='store_true') + parser.add_argument('--extract_feature', type=str, default=None) # 是否extract每个特征的分布 + parser.add_argument('--output_result_file', type=str, default=None) # 是否extract每个特征的分布 + parser.add_argument('--follow_plot', action='store_true') + parser.add_argument('--start_step',type=int,default=0) + parser.add_argument('--val',type = str, default='in') + + + args = parser.parse_args() + threshold_list = [round(0.05 * i, 2) for i in range(9)] + title_flag = True + # If we ever want to implement checkpointing, just persist these values + # every once in a while, and then load them from disk here. 
+ start_step = 0 + algorithm_dict = None + if args.extract_feature is not None: + args.output_dir = args.output_dir + args.extract_feature if args.output_dir[ + -1] == "/" else args.output_dir + "/" + args.extract_feature + + os.makedirs(args.output_dir, exist_ok=True) + sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) + sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) + + print("Environment:") + print("\tPython: {}".format(sys.version.split(" ")[0])) + print("\tPyTorch: {}".format(torch.__version__)) + print("\tTorchvision: {}".format(torchvision.__version__)) + print("\tCUDA: {}".format(torch.version.cuda)) + print("\tCUDNN: {}".format(torch.backends.cudnn.version())) + print("\tNumPy: {}".format(np.__version__)) + print("\tPIL: {}".format(PIL.__version__)) + + print('Args:') + for k, v in sorted(vars(args).items()): + print('\t{}: {}'.format(k, v)) + + if args.hparams_seed == 0: + hparams = hparams_registry.default_hparams( + args.algorithm, args.dataset) + else: + hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, + misc.seed_hash(args.hparams_seed, args.trial_seed)) + if args.hparams: + hparams.update(json.loads(args.hparams)) + + print('HParams:') + for k, v in sorted(hparams.items()): + print('\t{}: {}'.format(k, v)) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + if args.dataset in vars(datasets): + dataset = vars(datasets)[args.dataset](args.data_dir, + args.test_envs, hparams) + else: + raise NotImplementedError + + # Split each env into an 'in-split' and an 'out-split'. We'll train on + # each in-split except the test envs, and evaluate on all splits. + + # To allow unsupervised domain adaptation experiments, we split each test + # env into 'in-split', 'uda-split' and 'out-split'. 
The 'in-split' is used + # by collect_results.py to compute classification accuracies. The + # 'out-split' is used by the Oracle model selectino method. The unlabeled + # samples in 'uda-split' are passed to the algorithm at training time if + # args.task == "domain_adaptation". If we are interested in comparing + # domain generalization and domain adaptation results, then domain + # generalization algorithms should create the same 'uda-splits', which will + # be discared at training. + in_splits = [] + out_splits = [] + uda_splits = [] + for env_i, env in enumerate(dataset): + uda = [] + + out, in_ = misc.split_dataset(env, + int(len(env) * args.holdout_fraction), + misc.seed_hash(args.trial_seed, env_i)) + + if env_i in args.test_envs: + uda, in_ = misc.split_dataset(in_, + int(len(in_) * args.uda_holdout_fraction), + misc.seed_hash(args.trial_seed, env_i)) + + if hparams['class_balanced']: + in_weights = misc.make_weights_for_balanced_classes(in_) + out_weights = misc.make_weights_for_balanced_classes(out) + if uda is not None: + uda_weights = misc.make_weights_for_balanced_classes(uda) + else: + in_weights, out_weights, uda_weights = None, None, None + in_splits.append((in_, in_weights)) + out_splits.append((out, out_weights)) + if len(uda): + uda_splits.append((uda, uda_weights)) + + train_loaders = [InfiniteDataLoader( + dataset=env, + weights=env_weights, + batch_size=hparams['batch_size'], + num_workers=dataset.N_WORKERS) + for i, (env, env_weights) in enumerate(in_splits) + if i not in args.test_envs] + + uda_loaders = [InfiniteDataLoader( + dataset=env, + weights=env_weights, + batch_size=hparams['batch_size'], + num_workers=dataset.N_WORKERS) + for i, (env, env_weights) in enumerate(uda_splits) + if i in args.test_envs] + + eval_loaders = [FastDataLoader( + dataset=env, + batch_size=256, + num_workers=dataset.N_WORKERS) + for env, _ in (in_splits + out_splits + uda_splits)] + eval_weights = [None for _, weights in ( + in_splits + out_splits + 
uda_splits)] + eval_loader_names = ['env{}_in'.format(i) + for i in range(len(in_splits))] + eval_loader_names += ['env{}_out'.format(i) + for i in range(len(out_splits))] + eval_loader_names += ['env{}_uda'.format(i) + for i in range(len(uda_splits))] + + algorithm_class = algorithms.get_algorithm_class(args.algorithm) + algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, + len(dataset) - len(args.test_envs), hparams) + + if algorithm_dict is not None: + algorithm.load_state_dict(algorithm_dict) + + algorithm.to(device) + + train_minibatches_iterator = zip(*train_loaders) + uda_minibatches_iterator = zip(*uda_loaders) + checkpoint_vals = collections.defaultdict(lambda: []) + + steps_per_epoch = min([len(env) / hparams['batch_size'] + for env, _ in in_splits]) + + n_steps = args.steps or dataset.N_STEPS + checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ + + + def save_checkpoint(filename): + import copy + cpa = copy.deepcopy(algorithm) + if args.skip_model_save: + return + save_dict = { + "args": vars(args), + "model_input_shape": dataset.input_shape, + "model_num_classes": dataset.num_classes, + "model_num_domains": len(dataset) - len(args.test_envs), + "model_hparams": hparams, + "model_dict": cpa.cpu().state_dict() + } + del cpa + torch.save(save_dict, os.path.join(args.output_dir, filename)) + + + last_results_keys = None + + for step in range(start_step, n_steps): + step_start_time = time.time() + minibatches_device = [(x.to(device), y.to(device)) + for x, y in next(train_minibatches_iterator)] + # print(len(minibatches_device)) + if args.task == "domain_adaptation": + uda_device = [x.to(device) + for x, _ in next(uda_minibatches_iterator)] + else: + uda_device = None + + if args.algorithm == 'CutERM' and step==hparams['cut_step']and step not in [0,n_steps-1]: + if hparams['cut_percent'] > 1e-8: + datas = feature_extractor_for_train(algorithm, zip( + eval_loader_names, eval_loaders), device, dataset.num_classes) + env_list = 
['env{}'.format(i) for i in range(len(dataset))] + train_env = copy.deepcopy(env_list) + for ev in args.test_envs: + train_env.remove('env{}'.format(ev)) + feature_num = 512 if hparams['resnet18'] else 2048 + opt_for_train = moft.opt_kde(env_list, train_env, dataset.num_classes, + feature_num, datas, percent=hparams['cut_percent'], + sample_size=1000, device=device) + opt_for_train.pca() + trans_matrix = opt_for_train.forward(cal_info=True, set_mask=True) + else: + feature_num = 512 if hparams['resnet18'] else 2048 + trans_matrix = torch.eye(feature_num,device=device) + algorithm.update_classifer(trans_matrix) + + + step_vals = algorithm.update(minibatches_device, uda_device) + checkpoint_vals['step_time'].append(time.time() - step_start_time) + + for key, val in step_vals.items(): + checkpoint_vals[key].append(val) + if ((step % checkpoint_freq == 0) or (step == n_steps - 1)) and (step > args.start_step): + results = { + 'step': step, + 'epoch': step / steps_per_epoch, + } + + for key, val in checkpoint_vals.items(): + results[key] = np.mean(val) + + evals = zip(eval_loader_names, eval_loaders, eval_weights) + for name, loader, weights in evals: + start = time.time() + loaderlen = len(loader) + acc = misc.accuracy(algorithm, loader, weights, device) + # print("eavl " + name + " with loader len " + str(loaderlen) + " use time " + str(round(time.time()-start,3))) + results[name + '_acc'] = acc + + results_keys = sorted(results.keys()) + if results_keys != last_results_keys: + misc.print_row(results_keys, colwidth=12) + last_results_keys = results_keys + misc.print_row([results[key] for key in results_keys], + colwidth=12) + if args.dataset == 'ColoredMNIST': + store_list = ['env0_in_acc','env0_out_acc', 'env1_in_acc', + 'env1_out_acc','env2_in_acc', + 'env2_out_acc'] + else: + store_list = ['env0_in_acc','env0_out_acc', 'env1_in_acc', + 'env1_out_acc','env2_in_acc', + 'env2_out_acc','env3_in_acc','env3_out_acc'] + + if args.output_result_file is not None: + # 
print('enter this step') + assert args.extract_feature is not None + if title_flag and not os.path.exists(args.output_dir +'/'+ args.output_result_file): + title = "name" + for key in store_list: + title += ',' + key + for thr in threshold_list: + title += ',train_dis_{},test_dis_{},train_info_{},feature_num_{}'.format(thr,thr,thr,thr) + title += '\n' + with open(args.output_dir +'/'+ args.output_result_file, 'a+') as f: + f.write(title) + print('csv file created') + title = "name" + for key in store_list: + title += ',' + key + for thr in threshold_list: + title += ',train_dis_{},test_dis_{},train_info_{},feature_num_{}'.format(thr, thr, thr,thr) + title += '\n' + with open(args.output_dir + '/' + 'before_'+args.output_result_file, 'a+') as f: + f.write(title) + print('before csv file created') + title_flag = False + + with open(args.output_dir +'/'+ args.output_result_file, 'a+') as f: + res = args.extract_feature + '_step_{},'.format(step) + \ + str([results[key] for key in store_list])[1:-1] + ',' + if not args.follow_plot: + res += '\n' + f.write(res) + with open(args.output_dir +'/'+ 'before_'+args.output_result_file, 'a+') as f: + res = args.extract_feature + '_step_{},'.format(step) + \ + str([results[key] for key in store_list])[1:-1] + ',' + if not args.follow_plot: + res += '\n' + f.write(res) + + results.update({ + 'hparams': hparams, + 'args': vars(args) + }) + + epochs_path = os.path.join(args.output_dir, 'results.jsonl') + with open(epochs_path, 'a') as f: + f.write(json.dumps(results, sort_keys=True) + "\n") + + start_step = step + 1 + checkpoint_vals = collections.defaultdict(lambda: []) + + if args.save_feature_every_checkpoint: + if args.extract_feature is not None: + #print('________start feature extract_________') + if args.output_dir[-1] == '/': + marker = args.output_dir + "extracted_{}".format(step) + else: + marker = args.output_dir + "/" + "extracted_{}".format(step) + datas = feature_extractor_for_pipline(algorithm, zip( + 
eval_loader_names, eval_loaders), device, dataset.num_classes, marker,val=args.val) + env_list = ['env{}'.format(i) for i in range(len(dataset))] + train_env = copy.deepcopy(env_list) + for ev in args.test_envs: + train_env.remove('env{}'.format(ev)) + if args.dataset == 'ColoredMNIST': + feature_num = 128 + else: + feature_num = 512 if hparams['resnet18'] else 2048 + opt_for_pipline = opt_kde(env_list, train_env, dataset.num_classes, + feature_num, datas, sample_size=10000, device=device) + compute_result = torch_to_numpy( + opt_for_pipline.forward(cal_info=True, use_mean_info=True)) + compute_result['eig_value'] = opt_for_pipline.eig_val() + + mmstr = '_mean' + new_for_save = np.array(compute_result) + np.save(marker + "before_new_L1_" + mmstr + "_save.npy", new_for_save) + del new_for_save + train_distance = compute_result['train_results'].max(axis=0) + test_distance = compute_result['test_results'].max(axis=0) + info = compute_result['train_info'] + print("———————— before info filter ————————") + print("train_dis:", train_distance) + print("test_dis:", test_distance) + print("info:", info) + line = '' + for thr in threshold_list: + select_index = [i for i in range(len(info)) if info[i] >= thr] + # print(select_index) + if len(select_index) == 0: + train_mean = float('nan') + test_mean = float('nan') + info_mean = float('nan') + line += to_str([train_mean, test_mean, info_mean,0]) + else: + train_mean = train_distance[select_index].mean() + test_mean = test_distance[select_index].mean() + info_mean = info[select_index].mean() + line += to_str([train_mean, test_mean, info_mean,len(select_index)]) + line += '\n' + del compute_result + + if args.output_result_file is not None: + with open(args.output_dir + '/' + 'before_' + args.output_result_file, 'a+') as f: + f.write(line) + + for _ in range(4000): + opt_for_pipline.backward(backward_method='mean',lr=1.0) + compute_result = torch_to_numpy( + opt_for_pipline.forward(cal_info=True, use_mean_info=True)) + 
compute_result['eig_value'] = opt_for_pipline.eig_val() + + mmstr = '_mean' + new_for_save = np.array(compute_result) + np.save(marker + "new_L1_" + mmstr + "_save.npy", new_for_save) + del new_for_save + + train_distance = compute_result['train_results'].max(axis=0) + test_distance = compute_result['test_results'].max(axis=0) + info = compute_result['train_info'] + print("———————— info filter ————————") + print("train_dis:",train_distance) + print("test_dis:",test_distance) + print("info:",info) + line = '' + for thr in threshold_list: + select_index = [i for i in range(len(info)) if info[i] >= thr] + #print(select_index) + if len(select_index) == 0: + train_mean = float('nan') + test_mean = float('nan') + info_mean = float('nan') + line += to_str([train_mean, test_mean, info_mean,0]) + else: + train_mean = train_distance[select_index].mean() + test_mean = test_distance[select_index].mean() + info_mean = info[select_index].mean() + line += to_str([train_mean, test_mean, info_mean,len(select_index)]) + line += '\n' + if args.output_result_file is not None: + with open(args.output_dir +'/'+ args.output_result_file, 'a+') as f: + f.write(line) + del opt_for_pipline + + if not args.skip_model_save: + save_checkpoint(f'model_step{step}.pkl') + if not args.skip_model_save: + save_checkpoint('model.pkl') + + with open(os.path.join(args.output_dir, 'done'), 'w') as f: + f.write('done') + diff --git a/my_launcher.py b/my_launcher.py new file mode 100644 index 000000000..83a1a8708 --- /dev/null +++ b/my_launcher.py @@ -0,0 +1,305 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Run sweeps: build the cross-product of (dataset, algorithm, test env,
hparam combination, repeat) as `Job`s, then launch / delete / view them.
"""

import argparse
import copy
import getpass
import hashlib
import itertools
import json
import os
import random
import shlex
import shutil
import time
import uuid

import numpy as np
import torch
import tqdm

from domainbed import datasets
from domainbed import hparams_registry
from domainbed import algorithms
from domainbed.lib import misc
from domainbed import command_launchers


class Job:
    """One training run: its CLI command string plus on-disk state.

    State is inferred from the filesystem under the job's extract dir
    (``output_dir/<extract_feature>``):

    * DONE         — ``<extract>/done`` marker file exists,
    * INCOMPLETE   — the extract directory exists but no marker,
    * NOT_LAUNCHED — the extract directory does not exist.
    """

    NOT_LAUNCHED = 'Not launched'
    INCOMPLETE = 'Incomplete'
    DONE = 'Done'

    def __init__(self, train_args):
        """Build the ``python main.py ...`` command from a train-args dict."""
        self.output_dir = train_args['output_dir']
        self.train_args = copy.deepcopy(train_args)
        # Each run stores its artifacts under output_dir/<extract_feature>.
        self.extract = self.output_dir + '/' + train_args['extract_feature']

        command = ['python', 'main.py']
        for k, v in sorted(self.train_args.items()):
            if v == '':
                # Empty value encodes a bare flag (e.g. --skip_model_save).
                command.append(f'--{k} ')
                continue
            if isinstance(v, list):
                v = ' '.join([str(v_) for v_ in v])
            elif isinstance(v, str):
                v = shlex.quote(v)
            command.append(f'--{k} {v}')
        self.command_str = ' '.join(command)

        if os.path.exists(os.path.join(self.extract, 'done')):
            self.state = Job.DONE
        elif os.path.exists(self.extract):
            self.state = Job.INCOMPLETE
        else:
            self.state = Job.NOT_LAUNCHED

    def __str__(self):
        job_info = (self.train_args['dataset'],
                    self.train_args['algorithm'],
                    self.train_args['test_envs'],
                    self.train_args['extract_feature'])
        return '{}: {} {}'.format(self.state, self.extract, job_info)

    @staticmethod
    def launch(jobs, launcher_fn, available_list=None):
        """Create every job's output dir, then hand the commands to the launcher.

        ``available_list`` is the list of device/slot ids forwarded to
        ``launcher_fn``; defaults to ``[0, 1, 2, 3]``.  The default is built
        per call to avoid the shared mutable-default-argument pitfall.
        """
        if available_list is None:
            available_list = [0, 1, 2, 3]
        print('Launching...')
        jobs = jobs.copy()
        print('Making job directories:')
        for job in tqdm.tqdm(jobs, leave=False):
            os.makedirs(job.output_dir, exist_ok=True)
        commands = [job.command_str for job in jobs]
        launcher_fn(commands, available_list)
        print(f'Launched {len(jobs)} jobs!')

    @staticmethod
    def delete(jobs):
        """Remove each job's extract directory from disk (if it exists)."""
        print('Deleting...')
        for job in jobs:
            if os.path.isdir(job.extract):
                shutil.rmtree(job.extract)
        print(f'Deleted {len(jobs)} jobs!')


def all_test_env_combinations(n):
    """
    For a dataset with n >= 3 envs, yield all combinations of 1 and 2 test
    envs.
    """
    assert n >= 3
    for i in range(n):
        yield [i]
        for j in range(i + 1, n):
            yield [i, j]


def make_args_list(n_trials, dataset_names, algorithms, n_hparams_from,
                   n_hparams, steps, data_dir, task, holdout_fraction,
                   single_test_envs, hparams):
    """Build the cross-product of sweep configurations as train-arg dicts."""
    args_list = []
    for trial_seed in range(n_trials):
        for dataset in dataset_names:
            for algorithm in algorithms:
                if single_test_envs:
                    all_test_envs = [
                        [i] for i in range(datasets.num_environments(dataset))]
                else:
                    all_test_envs = all_test_env_combinations(
                        datasets.num_environments(dataset))
                for test_envs in all_test_envs:
                    for hparams_seed in range(n_hparams_from, n_hparams):
                        train_args = {}
                        train_args['dataset'] = dataset
                        train_args['algorithm'] = algorithm
                        train_args['test_envs'] = test_envs
                        train_args['holdout_fraction'] = holdout_fraction
                        train_args['hparams_seed'] = hparams_seed
                        train_args['data_dir'] = data_dir
                        train_args['task'] = task
                        train_args['trial_seed'] = trial_seed
                        # Deterministic per-configuration seed.
                        train_args['seed'] = misc.seed_hash(
                            dataset, algorithm, test_envs, hparams_seed,
                            trial_seed)
                        if steps is not None:
                            train_args['steps'] = steps
                        if hparams is not None:
                            train_args['hparams'] = hparams
                        args_list.append(train_args)
    return args_list


def ask_for_confirmation():
    """Prompt the user; exit the process unless the answer starts with 'y'."""
    response = input('Are you sure? (y/n) ')
    if not response.lower().strip()[:1] == "y":
        print('Nevermind!')
        exit(0)


DATASETS = [d for d in datasets.DATASETS if "Debug" not in d]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run a sweep')
    parser.add_argument('command',
        choices=['launch', 'delete_incomplete', 'just_view'])
    parser.add_argument('--datasets', nargs='+', type=str, default=DATASETS)
    parser.add_argument('--algorithms', nargs='+', type=str,
        default=algorithms.ALGORITHMS)
    parser.add_argument('--task', type=str, default="domain_generalization")
    parser.add_argument('--n_hparams_from', type=int, default=0)
    parser.add_argument('--n_hparams', type=int, default=20)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_trials', type=int, default=3)
    parser.add_argument('--command_launcher', type=str, required=True)
    parser.add_argument('--steps', type=int, default=None)
    parser.add_argument('--hparams', type=str, default=None)
    parser.add_argument('--holdout_fraction', type=float, default=0.2)
    parser.add_argument('--single_test_envs', action='store_true')
    parser.add_argument('--skip_confirmation', action='store_true')
    parser.add_argument('--train_algorithm', type=str, default='')
    args = parser.parse_args()
    debug = False

    # Device/slot ids handed to the launcher backend.
    available_list = [1, 2, 3]

    # Per-algorithm sweep configuration: number of repeats ('times'), the
    # hparam grid, the first step to evaluate from ('start_step'), and an
    # optional checkpoint frequency ('freq').
    algorithm_dict = {}
    algorithm_dict['CORAL'] = {'times': 5,
                               'hparam': {'lr': [1e-4, 5e-5],
                                          'mmd_gamma': [0.01, 0.1, 1, 10]},
                               'start_step': 0, 'freq': 500}

    if args.train_algorithm:
        # Restrict the sweep to an explicit comma-separated subset of the
        # configured algorithms.
        algorithm_dict_tmp = {}
        args.train_algorithm = args.train_algorithm.split(',')
        for algorithm in args.train_algorithm:
            algorithm_dict_tmp[algorithm] = copy.deepcopy(
                algorithm_dict[algorithm])
        algorithm_dict = algorithm_dict_tmp

    # Debug switch: a tiny ERM-only sweep with validation-split evaluation.
    val_test = False
    if val_test:
        algorithm_dict = {}
        algorithm_dict['ERM'] = {'times': 1,
                                 'hparam': {'lr': [1e-4, 3e-4, 5e-4]},
                                 'start_step': 0}

    args_list = []
    dataset_list = ['ColoredMNIST']

    # Which environments to hold out as test envs, per dataset.
    TEST_ENVS_BY_DATASET = {
        'OfficeHome': [0],
        'VLCS': [0, 1, 2, 3],
        'PACS': [0, 1, 2, 3],
        'ColoredMNIST': [2],
    }

    for data_set in dataset_list:
        try:
            test_env_list = TEST_ENVS_BY_DATASET[data_set]
        except KeyError:
            # Previously an unconfigured dataset either raised NameError or
            # silently reused the previous iteration's test envs; fail loudly.
            raise ValueError(
                'no test envs configured for dataset {}'.format(data_set))
        for test_env in test_env_list:
            for alg in algorithm_dict:
                hparams = {}
                train_args = {}
                train_args['algorithm'] = alg
                # Prefer an in-repo data directory when it exists.
                if os.path.exists('domainbed/{}'.format(data_set)):
                    train_args['data_dir'] = 'domainbed'
                else:
                    train_args['data_dir'] = 'domainbed/datasets'
                train_args['dataset'] = data_set
                train_args['test_envs'] = test_env
                train_args['steps'] = 1001 if data_set == 'ColoredMNIST' else 5001
                train_args['start_step'] = algorithm_dict[alg]['start_step']
                train_args['output_dir'] = (
                    'logs/{}_{}_test_env{}'.format(data_set, alg, test_env)
                    if not val_test else
                    'logs/val_test_{}_test_env{}'.format(alg, test_env))
                if val_test:
                    train_args['val'] = 'out'
                train_args['output_result_file'] = 'result.csv'
                if 'freq' in algorithm_dict[alg]:
                    train_args['checkpoint_freq'] = algorithm_dict[alg]['freq']
                else:
                    # Default: roughly ten checkpoints over the evaluated range.
                    train_args['checkpoint_freq'] = (
                        train_args['steps'] - train_args['start_step'] - 1) // 10
                # Empty string encodes a bare CLI flag (see Job.__init__).
                train_args['save_feature_every_checkpoint'] = ''
                train_args['skip_model_save'] = ''
                train_args['follow_plot'] = ''

                para_title = algorithm_dict[alg]['hparam'].keys()
                param_iter = itertools.product(
                    *list(algorithm_dict[alg]['hparam'].values()))
                for para_comb in param_iter:
                    for i, key in enumerate(para_title):
                        hparams[key] = para_comb[i]
                    raw_hparams = json.dumps(hparams)
                    # Make the JSON filesystem-safe for use in file names.
                    hparam_str = (raw_hparams.replace('.', '*')
                                  .replace(':', '=')
                                  .replace(' ', '')
                                  .replace(',', '_'))
                    full_times = algorithm_dict[alg]['times']
                    train_args['hparams'] = raw_hparams
                    for times in range(full_times):
                        file_name = (
                            "%s_%s_%s" % (alg, hparam_str, times)
                            if not debug else
                            "debug_%s_%s_%s" % (alg, hparam_str, times))
                        train_args['extract_feature'] = file_name
                        train_args['trial_seed'] = random.randint(100000, 999999)
                        args_list.append(copy.deepcopy(train_args))

    jobs = [Job(train_args) for train_args in args_list]

    for job in jobs:
        print(job)
    print("{} jobs: {} done, {} incomplete, {} not launched.".format(
        len(jobs),
        len([j for j in jobs if j.state == Job.DONE]),
        len([j for j in jobs if j.state == Job.INCOMPLETE]),
        len([j for j in jobs if j.state == Job.NOT_LAUNCHED])))

    if args.command == 'launch':
        to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED]
        print(f'About to launch {len(to_launch)} jobs.')
        if not args.skip_confirmation:
            ask_for_confirmation()
        launcher_fn = command_launchers.REGISTRY[args.command_launcher]
        Job.launch(to_launch, launcher_fn, available_list)
    elif args.command == 'delete_incomplete':
        to_delete = [j for j in jobs if j.state == Job.INCOMPLETE]
        print(f'About to delete {len(to_delete)} jobs.')
        if not args.skip_confirmation:
            ask_for_confirmation()
        Job.delete(to_delete)
    elif args.command == 'just_view':
        pass