MUSIQ metric needs to be added with a PyTorch module #13
Open
Description
I used a TensorFlow model for the MUSIQ metric, which is difficult to install. We should find a PyTorch implementation instead and update this part of the code so that it works properly.
nerfiller/nerfiller/utils/metrics.py
Lines 173 to 185 in e017c6e
```python
# scores = []
# for i in range(bs):
#     image = preds[i].permute(1, 2, 0)
#     img = Image.fromarray((image.detach().cpu().numpy() * 255).astype("uint8"))
#     image_bytes = io.BytesIO()
#     img.save(image_bytes, format="PNG")
#     image_bytes = image_bytes.getvalue()
#     x = tf.constant(image_bytes)
#     assert x.device.endswith("GPU:0")
#     aesthetic_score = self.predict_fn(x)
#     score = float(tf.squeeze(aesthetic_score["output_0"]).numpy())
#     scores.append(score)
# scores = torch.tensor(scores)
```
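One possible PyTorch replacement, sketched under assumptions: the third-party IQA-PyTorch package (`pyiqa`) provides a MUSIQ model via `pyiqa.create_metric("musiq")` that accepts float tensors of shape (N, 3, H, W) in [0, 1], which would remove the PNG encoding and the TensorFlow dependency from the loop above. The helper name `musiq_scores` and its `metric` parameter are hypothetical; the parameter lets a preloaded model (or a stub, for testing) be passed in.

```python
import torch


def musiq_scores(preds: torch.Tensor, metric=None) -> torch.Tensor:
    """Score a batch of RGB images (N, 3, H, W) with values in [0, 1].

    Hypothetical helper. Assumption: `metric` defaults to the MUSIQ model
    from the third-party `pyiqa` package; any callable that maps an
    (N, 3, H, W) tensor to N scores can be passed instead.
    """
    if preds.ndim != 4 or preds.shape[1] != 3:
        raise ValueError(f"expected (N, 3, H, W), got {tuple(preds.shape)}")
    if metric is None:
        import pyiqa  # assumption: installed with `pip install pyiqa`

        metric = pyiqa.create_metric("musiq", device=preds.device)
    # Clamp to the valid range instead of the uint8/PNG round trip
    # used by the TensorFlow version.
    x = preds.detach().clamp(0.0, 1.0)
    with torch.no_grad():
        scores = metric(x)
    # Flatten (N, 1) or (N,) outputs to a 1-D float tensor on the CPU,
    # like the old `torch.tensor(scores)` built from a list of floats.
    return scores.reshape(-1).float().cpu()
```

In the metrics code this could replace the commented-out loop with `scores = musiq_scores(preds)`; creating the `pyiqa` model once (for example in `__init__`) and passing it as `metric` avoids reloading the weights on every call. Scores from the `pyiqa` checkpoint are not guaranteed to match the TensorFlow model numerically, so earlier results may need to be recomputed.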