
Code errors in swad.py #28

@wangtz19

Description


There are some implementation errors in the update_and_evaluate function:

swad/domainbed/swad.py

Lines 83 to 123 in 252190e

def update_and_evaluate(self, segment_swa, val_acc, val_loss, prt_fn):
    if self.dead_valley:
        return

    frozen = copy.deepcopy(segment_swa.cpu())
    frozen.end_loss = val_loss
    self.converge_Q.append(frozen)
    self.smooth_Q.append(frozen)

    if not self.is_converged:
        if len(self.converge_Q) < self.n_converge:
            return

        min_idx = np.argmin([model.end_loss for model in self.converge_Q])
        untilmin_segment_swa = self.converge_Q[min_idx]  # until-min segment swa.
        if min_idx == 0:
            self.converge_step = self.converge_Q[0].end_step
            self.final_model = swa_utils.AveragedModel(untilmin_segment_swa)

            th_base = np.mean([model.end_loss for model in self.converge_Q])
            self.threshold = th_base * (1.0 + self.tolerance_ratio)

            if self.n_tolerance < self.n_converge:
                for i in range(self.n_converge - self.n_tolerance):
                    model = self.converge_Q[1 + i]
                    self.final_model.update_parameters(
                        model, start_step=model.start_step, end_step=model.end_step
                    )
            elif self.n_tolerance > self.n_converge:
                converge_idx = self.n_tolerance - self.n_converge
                Q = list(self.smooth_Q)[: converge_idx + 1]
                start_idx = 0
                for i in reversed(range(len(Q))):
                    model = Q[i]
                    if model.end_loss > self.threshold:
                        start_idx = i + 1
                        break
                for model in Q[start_idx + 1 :]:
                    self.final_model.update_parameters(
                        model, start_step=model.start_step, end_step=model.end_step
                    )
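For context, converge_Q and smooth_Q appear to be collections.deque buffers with maxlen n_converge and n_tolerance respectively (an assumption based on the rest of swad.py, which is not shown in the excerpt above). Under that assumption, a short sketch shows why Q[-1] == converge_Q[0] when the slice keeps converge_idx + 1 elements:

```python
from collections import deque

# Hypothetical buffer sizes; the real values come from the SWAD config.
n_converge, n_tolerance = 3, 5
converge_Q = deque(maxlen=n_converge)
smooth_Q = deque(maxlen=n_tolerance)

# Both deques receive the same "frozen" segment each step (here: ints).
for step in range(10):
    converge_Q.append(step)
    smooth_Q.append(step)

converge_idx = n_tolerance - n_converge
# The oldest entry of converge_Q sits at index converge_idx of smooth_Q,
# so Q = list(smooth_Q)[: converge_idx + 1] ends exactly at converge_Q[0].
assert converge_Q[0] == list(smooth_Q)[converge_idx]
```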

When min_idx == 0 holds, final_model is initialized from converge_Q[0]. In the self.n_tolerance > self.n_converge branch, Q = list(self.smooth_Q)[: converge_idx + 1] makes Q[-1] == converge_Q[0], so the subsequent update loop for model in Q[start_idx + 1 :]: causes two errors:
(1) converge_Q[0] is averaged into final_model twice (once at initialization and once in the loop)
(2) the left-most weight in Q whose end_loss <= self.threshold is skipped, because iteration starts at start_idx + 1 instead of start_idx

A refined version is below:

converge_idx = self.n_tolerance - self.n_converge
Q = list(self.smooth_Q)[: converge_idx]  # exclude converge_Q[0]: Q only contains the segments before converge_Q[0]
start_idx = 0
for i in reversed(range(len(Q))):
    if Q[i].end_loss > self.threshold:
        start_idx = i + 1
        break
for model in Q[start_idx:]:
    self.final_model.update_parameters(
        model, start_step=model.start_step, end_step=model.end_step
    )
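A minimal sketch (hypothetical end-loss values, with plain floats standing in for the frozen segments) contrasts which segments the original and refined loops would average:

```python
# Assume n_converge = 3 and n_tolerance = 5, so converge_idx = 2 and
# smooth_Q holds 5 segments; these loss values are made up for illustration.
end_losses = [0.9, 0.4, 0.3, 0.2, 0.1]  # smooth_Q contents, oldest -> newest
threshold = 0.5
converge_idx = 5 - 3                     # n_tolerance - n_converge

# Original loop: Q includes converge_Q[0] (== smooth_Q[converge_idx])
# and iteration starts at start_idx + 1.
Q = end_losses[: converge_idx + 1]       # [0.9, 0.4, 0.3]
start_idx = 0
for i in reversed(range(len(Q))):
    if Q[i] > threshold:
        start_idx = i + 1                # stops at 0.9 -> start_idx = 1
        break
buggy = Q[start_idx + 1 :]               # [0.3]: 0.4 is omitted, and 0.3
                                         # is converge_Q[0], already averaged
                                         # into final_model at initialization

# Refined loop: Q excludes converge_Q[0] and iteration starts at start_idx.
Q = end_losses[:converge_idx]            # [0.9, 0.4]
start_idx = 0
for i in reversed(range(len(Q))):
    if Q[i] > threshold:
        start_idx = i + 1                # stops at 0.9 -> start_idx = 1
        break
fixed = Q[start_idx:]                    # [0.4]: exactly the segments below
                                         # the threshold, each averaged once
print(buggy, fixed)
```

With the refined slice, converge_Q[0] is dropped up front, so starting the final loop at start_idx picks up every segment whose end_loss is at or below the threshold exactly once.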
