| from statistics import median |
| from skimage.metrics import structural_similarity |
|
|
def getSSIM(gt, out, gt_flag=None, data_range=1):
    """Return the median structural similarity between ``gt`` and ``out``.

    One SSIM score is computed for every ``[i, j, ...]`` slice pair whose
    leading index ``i`` is enabled in ``gt_flag``; the median of all those
    scores is returned.

    Args:
        gt: Reference array, indexed as ``gt[i, j, ...]``.
            NOTE(review): presumably (batch, channel, H, W) -- confirm
            against the caller.
        out: Array compared against ``gt``, indexed the same way.
        gt_flag: Per-``i`` truth values selecting which leading indices
            contribute. ``None`` selects all of them.
        data_range: Forwarded unchanged to ``structural_similarity``.

    Raises:
        statistics.StatisticsError: If no slice is selected, because the
            median of an empty list is undefined.
    """
    num_items = gt.shape[0]
    flags = [True] * num_items if gt_flag is None else gt_flag

    scores = []
    for item in range(num_items):
        if flags[item]:
            for plane in range(gt.shape[1]):
                score = structural_similarity(
                    gt[item, plane, ...],
                    out[item, plane, ...],
                    data_range=data_range,
                )
                scores.append(score)
    return median(scores)
|
|
def ema(source, target, decay):
    """Blend ``source``'s state into ``target`` in place.

    For every key in ``source.state_dict()`` the matching ``target`` entry
    is overwritten with ``target * decay + source * (1 - decay)``.

    Args:
        source: Object whose ``state_dict()`` supplies the new values.
            NOTE(review): presumably a ``torch.nn.Module`` -- confirm.
        target: Object whose ``state_dict()`` entries are updated in place;
            it must contain every key present in ``source``.
        decay: Weight kept from the current ``target`` value.
    """
    src_state = source.state_dict()
    dst_state = target.state_dict()
    for name, src_entry in src_state.items():
        dst_values = dst_state[name].data
        blended = dst_values * decay + src_entry.data * (1 - decay)
        # copy_ writes into the existing storage instead of rebinding the entry.
        dst_values.copy_(blended)
|
|
|
|
class WarmupLR:
    """Linear warmup multiplier: ``step / warmup``, capped at 1.0."""

    def __init__(self, warmup) -> None:
        """Store the number of steps over which the multiplier ramps up.

        Args:
            warmup: Length of the warmup phase in steps. A value of zero
                or less disables warmup.
        """
        self.warmup = warmup

    def __call__(self, step):
        """Return the multiplier for ``step``.

        Args:
            step: Current step index.

        Returns:
            ``min(step, warmup) / warmup`` when ``warmup`` is positive,
            otherwise ``1.0``.
        """
        if self.warmup <= 0:
            # No warmup phase: avoid dividing by zero and use the full value.
            return 1.0
        return min(step, self.warmup) / self.warmup