Xinyi Wang committed on
Commit ·
db25ead
0
Parent(s):
project files
- .gitignore +3 -0
- README.md +198 -0
- metadata/KVQ_metadata.csv +0 -0
- metadata/LSVQ_TEST_1080P_metadata.csv +0 -0
- metadata/LSVQ_TEST_metadata.csv +0 -0
- metadata/LSVQ_TRAIN_metadata.csv +0 -0
- metadata/SHORTS-HDR2SDR-DATASET_metadata.csv +0 -0
- metadata/SHORTS-SDR-DATASET_metadata.csv +0 -0
- requirements.txt +75 -0
- src/correlation_result.ipynb +914 -0
- src/demo_test.py +264 -0
- src/model/__init__.py +1 -0
- src/model/clip_dense_encoder.py +61 -0
- src/model/qd_model.py +178 -0
- src/module/__init__.py +1 -0
- src/module/compute_weight_map.py +144 -0
- src/module/frequency_dct.py +351 -0
- src/module/read_frame_decord.py +138 -0
- src/train.py +604 -0
- src/transfer.py +291 -0
- src/transfer_test_only.py +284 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
.gitignore
|
| 2 |
+
.DS_Store
|
| 3 |
+
test_videos/
|
README.md
ADDED
|
@@ -0,0 +1,198 @@
| 1 |
+
# FGSVQA
|
| 2 |
+

|
| 3 |
+

|
| 4 |
+

|
| 5 |
+
[arXiv:2605.20016](http://arxiv.org/abs/2605.20016)
|
| 6 |
+
|
| 7 |
+
Official Code for the following paper:
|
| 8 |
+
|
| 9 |
+
**X. Wang, A. Katsenou, J. Shen and D. Bull**. [FGSVQA: Frequency-Guided Short-form Video Quality Assessment](http://arxiv.org/abs/2605.20016)
|
| 10 |
+
|
| 11 |
+
[Our paper](http://arxiv.org/abs/2605.20016) was accepted at the 18th International Conference on Quality of Multimedia Experience ([QoMEX 2026](https://qomex2026.itec.aau.at/)).
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Performance
|
| 16 |
+
We validated the proposed method on two publicly available short-form UGC datasets: KVQ and the YouTube SFV+HDR dataset (YT-SFV).
|
| 17 |
+
|
| 18 |
+
#### **Spearman’s Rank Correlation Coefficient (SRCC)**
|
| 19 |
+
| **Model** | **KVQ** | **YT-SFV (SDR)** | **YT-SFV (HDR2SDR)** |
|
| 20 |
+
|----------------------------|-----------|------------------|----------------------|
|
| 21 |
+
| FGSVQA | 0.877 | 0.788 | 0.543 |
|
| 22 |
+
|
| 23 |
+
#### **Pearson’s Linear Correlation Coefficient (PLCC)**
|
| 24 |
+
| **Model** | **KVQ** | **YT-SFV (SDR)** | **YT-SFV (HDR2SDR)** |
|
| 25 |
+
|----------------------------|-----------|------------------|----------------------|
|
| 26 |
+
| FGSVQA | 0.878 | 0.818 | 0.666 |
|
| 27 |
+
|
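The two correlation metrics reported above can be sketched in a few lines of plain Python. This is a minimal illustration, not the evaluation code of this repository: the function names `average_ranks`, `plcc`, and `srcc` and the sample scores are assumptions made for the example. Some VQA papers also fit a nonlinear mapping before computing PLCC; whether that is done here is not stated in the README, so the sketch omits it.

```python
import math


def average_ranks(values):
    """Return 1-based ranks, giving tied values the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with order[i].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def plcc(x, y):
    """Pearson's linear correlation coefficient of two equal-length sequences."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(x) / len(x)
    mean_y = sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


def srcc(x, y):
    """Spearman's rank correlation: Pearson correlation of the ranks."""
    return plcc(average_ranks(x), average_ranks(y))


# Hypothetical MOS labels and model predictions, for illustration only.
mos = [4.3, 2.1, 3.6, 1.8, 3.9]
pred = [3.9, 2.4, 3.5, 2.0, 3.6]
print(f"SRCC={srcc(mos, pred):.3f}  PLCC={plcc(mos, pred):.3f}")
```

SRCC depends only on the ordering of the scores, so it rewards a model that ranks videos correctly even when its scale is off, while PLCC measures how linear the relationship is.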
| 28 |
+
#### **GPU runtime comparison (averaged over 10 runs) across different spatial resolutions on "SDR\_Animal\_5ngj.mp4".**
|
| 29 |
+
| Method | Time(s)<br>540P | Time(s)<br>720P | Time(s)<br>1080P | Time(s)<br>2160P | Ground truth: 4.308<br>Predicted Score|
|
| 30 |
+
|---|------------:|------------:|-------------:|---:|---:|
|
| 31 |
+
| Fast-VQA | 0.599 | 0.673 | 0.909 | 2.217 | 3.319 |
|
| 32 |
+
| FasterVQA | 0.489 | 0.547 | **0.696** | **1.343** | 3.556 |
|
| 33 |
+
| DOVER | 0.920 | 1.022 | 1.293 | 2.783 | 3.814 |
|
| 34 |
+
| FGSVQA | **0.313** | **0.405** | 0.697 | 2.137 | **3.878** |
|
| 35 |
+
|
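A runtime averaged over several runs, as in the table above, can be measured with a small timing helper. The sketch below is an assumption about one reasonable way to do it, not a description of how the reported numbers were obtained: `mean_runtime`, its `warmup` parameter, and the stand-in workload are all hypothetical.

```python
import time


def mean_runtime(fn, runs=10, warmup=2, sync=None):
    """Average wall-clock seconds of fn() over `runs` calls after `warmup` calls."""
    if runs < 1:
        raise ValueError("runs must be >= 1")
    for _ in range(warmup):
        fn()  # warm-up calls are executed but not timed
    if sync is not None:
        sync()
    total = 0.0
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        total += time.perf_counter() - start
    return total / runs


# Stand-in workload; in practice fn would wrap one forward pass over a video.
elapsed = mean_runtime(lambda: sum(i * i for i in range(10_000)))
print(f"mean runtime: {elapsed:.6f} s")
```

When timing a PyTorch model on a GPU, passing `sync=torch.cuda.synchronize` makes each measurement include the kernels that were launched asynchronously.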
| 36 |
+
More results can be found in **[correlation_result.ipynb](https://github.com/xinyiW915/FGSVQA/blob/main/src/correlation_result.ipynb)**.
|
| 37 |
+
|
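The linked notebook collects these numbers by scanning training logs with regular expressions. The sketch below mirrors that idea for a single log, using the filename and result patterns from the notebook; the helper name `parse_log` and the sample log text are assumptions, and the `loss` and `rmse` values are made up. Unlike the notebook cell, it skips files whose names do not match instead of failing on them.

```python
import re

FILENAME_PATTERN = re.compile(r"^train_(?P<dataset>.+)\.log$")
NUMBER = r"[-+]?\d*\.?\d+"
RESULT_PATTERN = re.compile(
    r"TEST\s*\|\s*"
    rf"loss=(?P<loss>{NUMBER})\s+"
    rf"plcc=(?P<plcc>{NUMBER})\s+"
    rf"srcc=(?P<srcc>{NUMBER})\s+"
    rf"rmse=(?P<rmse>{NUMBER})"
)


def parse_log(filename, text):
    """Return the last TEST result in a log, or None if name or content does not match."""
    name_match = FILENAME_PATTERN.match(filename)
    if name_match is None:
        return None  # not a train_<dataset>.log file
    matches = list(RESULT_PATTERN.finditer(text))
    if not matches:
        return None  # the log contains no TEST line
    last = matches[-1]
    return {
        "dataset": name_match.group("dataset"),
        "srcc": float(last.group("srcc")),
        "plcc": float(last.group("plcc")),
    }


# Hypothetical log excerpt; the loss and rmse values are made up for illustration.
sample = (
    "epoch 1 TEST | loss=0.5000 plcc=0.8000 srcc=0.7900 rmse=0.6000\n"
    "epoch 2 TEST | loss=0.4000 plcc=0.8781 srcc=0.8768 rmse=0.5000\n"
)
print(parse_log("train_kvq.log", sample))  # {'dataset': 'kvq', 'srcc': 0.8768, 'plcc': 0.8781}
print(parse_log("notes.log", sample))      # None
```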
| 38 |
+
## Proposed Model
|
| 39 |
+
Overview of the proposed model with the two branches: the frequency-guided weight map and the CLIP vision encoder.
|
| 40 |
+
|
| 41 |
+
<img src="./SVQA.png" alt="proposed_FGSVQA_framework" width="800"/>
|
| 42 |
+
|
| 43 |
+
## Usage
|
| 44 |
+
### 📌 Install Requirements
|
| 45 |
+
The repository is built with **Python 3.10** and can be installed via the following commands:
|
| 46 |
+
|
| 47 |
+
```shell
|
| 48 |
+
git clone https://github.com/xinyiW915/FGSVQA.git
|
| 49 |
+
cd FGSVQA
|
| 50 |
+
conda create -n fgsvqa python=3.10 -y
|
| 51 |
+
conda activate fgsvqa
|
| 52 |
+
pip install -r requirements.txt
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
### 📥 Download UGC Datasets
|
| 56 |
+
The corresponding UGC video datasets can be downloaded from the following sources:
|
| 57 |
+
[KVQ](https://lixinustc.github.io/projects/KVQ/), [YouTube SFV+HDR](https://media.withyoutube.com/sfv-hdr).
|
| 58 |
+
|
| 59 |
+
The metadata for the UGC datasets used in our experiments is available under [`./metadata`](./metadata).
|
| 60 |
+
|
| 61 |
+
### 🎬 Test Demo
|
| 62 |
+
Run the pre-trained model to evaluate the perceptual quality of a single video. The demo script reports the predicted quality score, runtime, and model complexity.
|
| 63 |
+
|
| 64 |
+
The model checkpoint should be provided through `--ckpt_path`. Please use a full checkpoint file, such as `qd_model.best.pt`, which contains the saved model weights together with the training MOS mean and standard deviation.
|
| 65 |
+
|
| 66 |
+
To evaluate a single video, run:
|
| 67 |
+
```shell
|
| 68 |
+
python demo_test.py \
|
| 69 |
+
--ckpt_path <MODEL_PATH> \
|
| 70 |
+
--db_path <VIDEO_FOLDER> \
|
| 71 |
+
--video_id <VIDEO_ID> \
|
| 72 |
+
--device <DEVICE>
|
| 73 |
+
```
|
| 74 |
+
For example:
|
| 75 |
+
```shell
|
| 76 |
+
python demo_test.py \
|
| 77 |
+
--ckpt_path ./checkpoints/lsvq/qd_model.best.pt \
|
| 78 |
+
--db_path ./test_videos/ \
|
| 79 |
+
--video_id SDR_Animal_5ngj \
|
| 80 |
+
--device cuda
|
| 81 |
+
```
|
| 82 |
+
|
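The checkpoint note above says the file stores the training MOS mean and standard deviation alongside the weights. A common reason to keep these statistics is that the model is trained on standardized targets and its output is mapped back to the MOS scale at test time. The README does not confirm that this is what `demo_test.py` does, so the sketch below is only an illustration of that assumption; `denormalize` and the numbers are hypothetical.

```python
def denormalize(z_score, mos_mean, mos_std):
    """Map a standardized prediction back to the MOS scale of the training set."""
    if mos_std <= 0:
        raise ValueError("mos_std must be positive")
    return z_score * mos_std + mos_mean


# Hypothetical statistics and raw model output, for illustration only.
mos_mean, mos_std = 3.2, 0.8
raw_output = 0.5
print(f"predicted MOS: {denormalize(raw_output, mos_mean, mos_std):.3f}")
```

Under this assumption, a checkpoint that lacks the two statistics could not produce scores on the original MOS scale, which would explain why a full checkpoint file is requested.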
| 83 |
+
### 🔁 Cross-Dataset Evaluation
|
| 84 |
+
To evaluate a trained model on another dataset, use `transfer_test_only.py`. This script loads a trained checkpoint, reports the evaluation metrics, and saves the prediction results to a CSV file.
|
| 85 |
+
|
| 86 |
+
Run:
|
| 87 |
+
```shell
|
| 88 |
+
python transfer_test_only.py \
|
| 89 |
+
--ckpt_path <MODEL_PATH> \
|
| 90 |
+
--csv_path <TEST_METADATA_CSV> \
|
| 91 |
+
--db_path <TEST_VIDEO_FOLDER> \
|
| 92 |
+
--device <DEVICE> \
|
| 93 |
+
--save_pred_csv <SAVE_PREDICTION_CSV>
|
| 94 |
+
```
|
| 95 |
+
For example:
|
| 96 |
+
```shell
|
| 97 |
+
python transfer_test_only.py \
|
| 98 |
+
--ckpt_path ./checkpoints/lsvq/qd_model.best.pt \
|
| 99 |
+
--csv_path ./metadata/KVQ_metadata.csv \
|
| 100 |
+
--db_path /path/to/KVQ/videos \
|
| 101 |
+
--device cuda \
|
| 102 |
+
--save_pred_csv /path/to/transfer_test_only_kvq.csv
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Training
|
| 106 |
+
Steps to train and fine-tune the model on different datasets.
|
| 107 |
+
|
| 108 |
+
### Train Model
|
| 109 |
+
Train the model using the metadata CSV file and the corresponding video folder. The metadata CSV file should contain `vid` and `mos` columns.
|
| 110 |
+
|
| 111 |
+
```shell
|
| 112 |
+
python train.py \
|
| 113 |
+
--csv_path <TRAIN_METADATA_CSV> \
|
| 114 |
+
--db_path <VIDEO_FOLDER> \
|
| 115 |
+
--save_dir <SAVE_DIR> \
|
| 116 |
+
--save_name qd_model.pt \
|
| 117 |
+
--device <DEVICE> \
|
| 118 |
+
--finetune_last_stage
|
| 119 |
+
```
|
| 120 |
+
For example:
|
| 121 |
+
```shell
|
| 122 |
+
python train.py \
|
| 123 |
+
--csv_path ./metadata/KVQ_TRAIN_metadata.csv \
|
| 124 |
+
--db_path /path/to/KVQ/videos \
|
| 125 |
+
--save_dir ./checkpoints/kvq \
|
| 126 |
+
--save_name qd_model.pt \
|
| 127 |
+
--device cuda \
|
| 128 |
+
--finetune_last_stage
|
| 129 |
+
```
|
| 130 |
+
The script saves the latest checkpoint and the best-performing checkpoint according to the validation SRCC.
|
| 131 |
+
|
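Since training expects a metadata CSV with `vid` and `mos` columns, it can help to check a file before starting a long run. The following is a minimal sketch using only the standard library; `load_metadata` and the sample rows are assumptions for illustration and are not part of this repository.

```python
import csv
import io

REQUIRED_COLUMNS = ("vid", "mos")


def load_metadata(handle):
    """Read (vid, mos) pairs from an open CSV file, validating the header first."""
    reader = csv.DictReader(handle)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"metadata CSV is missing columns: {missing}")
    # float() fails early on non-numeric MOS values.
    return [(row["vid"], float(row["mos"])) for row in reader]


# In-memory stand-in for a metadata file; the rows are made up for illustration.
sample = io.StringIO("vid,mos\nvideo_0001,3.75\nvideo_0002,2.10\n")
print(load_metadata(sample))  # [('video_0001', 3.75), ('video_0002', 2.1)]
```

For a real file, the same function can be called as `load_metadata(open(path, newline=""))`, where `path` points to one of the CSV files under `./metadata`.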
| 132 |
+
### Transfer Model
|
| 133 |
+
To fine-tune a pre-trained model on a new dataset, run:
|
| 134 |
+
|
| 135 |
+
```shell
|
| 136 |
+
python transfer.py \
|
| 137 |
+
--mode finetune \
|
| 138 |
+
--pretrained <PRETRAINED_MODEL_PATH> \
|
| 139 |
+
--csv_path <TARGET_METADATA_CSV> \
|
| 140 |
+
--db_path <TARGET_VIDEO_FOLDER> \
|
| 141 |
+
--save_dir <SAVE_DIR> \
|
| 142 |
+
--save_name transfer.pt \
|
| 143 |
+
--device <DEVICE> \
|
| 144 |
+
--finetune_last_stage
|
| 145 |
+
```
|
| 146 |
+
For example:
|
| 147 |
+
```shell
|
| 148 |
+
python transfer.py \
|
| 149 |
+
--mode finetune \
|
| 150 |
+
--pretrained ./checkpoints/shorts-hdr-dataset_sdr/qd_model.best.pt \
|
| 151 |
+
--csv_path ./metadata/KVQ_TRAIN_metadata.csv \
|
| 152 |
+
--db_path /path/to/KVQ/videos \
|
| 153 |
+
--save_dir ./checkpoints_transfer/kvq \
|
| 154 |
+
--save_name transfer.pt \
|
| 155 |
+
--device cuda \
|
| 156 |
+
--finetune_last_stage
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Test Only
|
| 160 |
+
To directly test a pre-trained model on another dataset, run:
|
| 161 |
+
|
| 162 |
+
```shell
|
| 163 |
+
python transfer.py \
|
| 164 |
+
--mode test_only \
|
| 165 |
+
--pretrained <PRETRAINED_MODEL_PATH> \
|
| 166 |
+
--csv_path <TEST_METADATA_CSV> \
|
| 167 |
+
--db_path <TEST_VIDEO_FOLDER> \
|
| 168 |
+
--device <DEVICE>
|
| 169 |
+
```
|
| 170 |
+
For example:
|
| 171 |
+
```shell
|
| 172 |
+
python transfer.py \
|
| 173 |
+
--mode test_only \
|
| 174 |
+
--pretrained ./checkpoints/shorts-hdr-dataset_sdr/qd_model.best.pt \
|
| 175 |
+
--csv_path ./metadata/KVQ_metadata.csv \
|
| 176 |
+
--db_path /path/to/KVQ/videos \
|
| 177 |
+
--device cuda
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
## Acknowledgment
|
| 182 |
+
This work was funded by the UKRI MyWorld Strength in Places Programme (SIPF00006/1) as part of my PhD study.
|
| 183 |
+
|
| 184 |
+
## Citation
|
| 185 |
+
If you find this paper and the repo useful, please cite our paper 😊:
|
| 186 |
+
|
| 187 |
+
```bibtex
|
| 188 |
+
@inproceedings{wang2026fgsvqa,
|
| 189 |
+
title={FGSVQA: Frequency-Guided Short-form Video Quality Assessment},
|
| 190 |
+
author={Wang, Xinyi and Katsenou, Angeliki and Shen, Junxiao and Bull, David},
|
| 191 |
+
booktitle={2026 18th International Conference on Quality of Multimedia Experience (QoMEX)},
|
| 192 |
+
year={2026},
|
| 193 |
+
organization={IEEE}
|
| 194 |
+
}
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
## Contact
|
| 198 |
+
Xinyi WANG, ```xinyi.wang@bristol.ac.uk```
|
metadata/KVQ_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metadata/LSVQ_TEST_1080P_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metadata/LSVQ_TEST_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metadata/LSVQ_TRAIN_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metadata/SHORTS-HDR2SDR-DATASET_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
metadata/SHORTS-SDR-DATASET_metadata.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,75 @@
| 1 |
+
accelerate==1.12.0
|
| 2 |
+
annotated-doc==0.0.4
|
| 3 |
+
anyio==4.12.1
|
| 4 |
+
av==16.1.0
|
| 5 |
+
bitsandbytes==0.49.1
|
| 6 |
+
certifi==2026.1.4
|
| 7 |
+
click==8.3.1
|
| 8 |
+
contourpy==1.3.2
|
| 9 |
+
cuda-bindings==12.9.4
|
| 10 |
+
cuda-pathfinder==1.3.3
|
| 11 |
+
cycler==0.12.1
|
| 12 |
+
decord==0.6.0
|
| 13 |
+
einops==0.8.2
|
| 14 |
+
exceptiongroup==1.3.1
|
| 15 |
+
filelock==3.20.0
|
| 16 |
+
fonttools==4.61.1
|
| 17 |
+
fsspec==2025.12.0
|
| 18 |
+
h11==0.16.0
|
| 19 |
+
hf-xet==1.2.0
|
| 20 |
+
httpcore==1.0.9
|
| 21 |
+
httpx==0.28.1
|
| 22 |
+
huggingface_hub==1.4.1
|
| 23 |
+
idna==3.11
|
| 24 |
+
Jinja2==3.1.6
|
| 25 |
+
kiwisolver==1.4.9
|
| 26 |
+
mamba-ssm==2.3.0
|
| 27 |
+
markdown-it-py==4.0.0
|
| 28 |
+
MarkupSafe==2.1.5
|
| 29 |
+
matplotlib==3.10.8
|
| 30 |
+
mdurl==0.1.2
|
| 31 |
+
mpmath==1.3.0
|
| 32 |
+
networkx==3.4.2
|
| 33 |
+
ninja==1.13.0
|
| 34 |
+
numpy==2.2.6
|
| 35 |
+
nvidia-cublas-cu12==12.1.3.1
|
| 36 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
| 37 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
| 38 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
| 39 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 40 |
+
nvidia-cufft-cu12==11.0.2.54
|
| 41 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 42 |
+
nvidia-curand-cu12==10.3.2.106
|
| 43 |
+
nvidia-cusolver-cu12==11.4.5.107
|
| 44 |
+
nvidia-cusparse-cu12==12.1.0.106
|
| 45 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 46 |
+
nvidia-nccl-cu12==2.21.5
|
| 47 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 48 |
+
nvidia-nvshmem-cu12==3.4.5
|
| 49 |
+
nvidia-nvtx-cu12==12.1.105
|
| 50 |
+
opencv-python==4.13.0.92
|
| 51 |
+
packaging
|
| 52 |
+
pillow==12.1.1
|
| 53 |
+
psutil==7.2.2
|
| 54 |
+
Pygments==2.19.2
|
| 55 |
+
pyparsing==3.3.2
|
| 56 |
+
python-dateutil==2.9.0.post0
|
| 57 |
+
PyYAML==6.0.3
|
| 58 |
+
regex==2026.1.15
|
| 59 |
+
rich==14.3.2
|
| 60 |
+
safetensors==0.7.0
|
| 61 |
+
scipy==1.15.3
|
| 62 |
+
shellingham==1.5.4
|
| 63 |
+
six==1.17.0
|
| 64 |
+
sympy==1.13.1
|
| 65 |
+
tokenizers==0.22.2
|
| 66 |
+
torch==2.5.1+cu121
|
| 67 |
+
torchaudio==2.5.1+cu121
|
| 68 |
+
torchvision==0.20.1+cu121
|
| 69 |
+
tqdm==4.67.3
|
| 70 |
+
transformers==5.1.0
|
| 71 |
+
triton==3.1.0
|
| 72 |
+
typer==0.23.0
|
| 73 |
+
typer-slim==0.23.0
|
| 74 |
+
typing_extensions==4.15.0
|
| 75 |
+
xformers==0.0.34
|
src/correlation_result.ipynb
ADDED
|
@@ -0,0 +1,914 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"id": "initial_id",
|
| 6 |
+
"metadata": {
|
| 7 |
+
"collapsed": true,
|
| 8 |
+
"ExecuteTime": {
|
| 9 |
+
"end_time": "2026-05-19T12:42:02.280702Z",
|
| 10 |
+
"start_time": "2026-05-19T12:42:02.276769Z"
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
"source": [
|
| 14 |
+
"from pathlib import Path\n",
|
| 15 |
+
"import re\n",
|
| 16 |
+
"import pandas as pd"
|
| 17 |
+
],
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"execution_count": 19
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"source": "Intra-dataset evaluation: Train and test on the same \"dataset\"",
|
| 25 |
+
"id": "145ddc0eb7f59e45"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"metadata": {
|
| 29 |
+
"ExecuteTime": {
|
| 30 |
+
"end_time": "2026-05-19T12:42:02.291070Z",
|
| 31 |
+
"start_time": "2026-05-19T12:42:02.288978Z"
|
| 32 |
+
}
|
| 33 |
+
},
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"source": "LOG_ROOT = Path(\"./checkpoints\")",
|
| 36 |
+
"id": "eba6249c8ad34da8",
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"execution_count": 20
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"metadata": {
|
| 42 |
+
"ExecuteTime": {
|
| 43 |
+
"end_time": "2026-05-19T12:42:02.315625Z",
|
| 44 |
+
"start_time": "2026-05-19T12:42:02.302579Z"
|
| 45 |
+
}
|
| 46 |
+
},
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"source": [
|
| 49 |
+
"filename_pattern = re.compile(r\"^train_(?P<dataset>.+)\\.log$\")\n",
|
| 50 |
+
"result_pattern = re.compile(\n",
|
| 51 |
+
" r\"TEST\\s*\\|\\s*\"\n",
|
| 52 |
+
" r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 53 |
+
" r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 54 |
+
" r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 55 |
+
" r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
|
| 56 |
+
")\n",
|
| 57 |
+
"records = []\n",
|
| 58 |
+
"for log_path in LOG_ROOT.rglob(\"*.log\"):\n",
|
| 59 |
+
" filename_match = filename_pattern.match(log_path.name)\n",
|
| 60 |
+
" dataset = filename_match.group(\"dataset\")\n",
|
| 61 |
+
"\n",
|
| 62 |
+
" text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
|
| 63 |
+
" matches = list(result_pattern.finditer(text))\n",
|
| 64 |
+
" if not matches:\n",
|
| 65 |
+
" continue\n",
|
| 66 |
+
" records.append({\n",
|
| 67 |
+
" \"dataset\": dataset,\n",
|
| 68 |
+
" \"srcc\": float(matches[-1].group(\"srcc\")),\n",
|
| 69 |
+
" \"plcc\": float(matches[-1].group(\"plcc\")),\n",
|
| 70 |
+
" })\n",
|
| 71 |
+
"train_df = (pd.DataFrame(records).sort_values(\"dataset\", ignore_index=True)[[\"dataset\", \"srcc\", \"plcc\"]])\n",
|
| 72 |
+
"train_df"
|
| 73 |
+
],
|
| 74 |
+
"id": "4f2532e62a94f758",
|
| 75 |
+
"outputs": [
|
| 76 |
+
{
|
| 77 |
+
"data": {
|
| 78 |
+
"text/plain": [
|
| 79 |
+
" dataset srcc plcc\n",
|
| 80 |
+
"0 finevd 0.8596 0.8579\n",
|
| 81 |
+
"1 konvid_1k 0.7970 0.8171\n",
|
| 82 |
+
"2 kvq 0.8768 0.8781\n",
|
| 83 |
+
"3 live_vqc 0.6386 0.7192\n",
|
| 84 |
+
"4 live_yt_gaming 0.8554 0.8755\n",
|
| 85 |
+
"5 lsvq 0.8753 0.8760\n",
|
| 86 |
+
"6 shorts-hdr-dataset_hdr2sdr 0.5431 0.6657\n",
|
| 87 |
+
"7 shorts-hdr-dataset_sdr 0.7884 0.8182\n",
|
| 88 |
+
"8 youtube_ugc 0.8325 0.8548"
|
| 89 |
+
],
|
| 90 |
+
"text/html": [
|
| 91 |
+
"<div>\n",
|
| 92 |
+
"<style scoped>\n",
|
| 93 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 94 |
+
" vertical-align: middle;\n",
|
| 95 |
+
" }\n",
|
| 96 |
+
"\n",
|
| 97 |
+
" .dataframe tbody tr th {\n",
|
| 98 |
+
" vertical-align: top;\n",
|
| 99 |
+
" }\n",
|
| 100 |
+
"\n",
|
| 101 |
+
" .dataframe thead th {\n",
|
| 102 |
+
" text-align: right;\n",
|
| 103 |
+
" }\n",
|
| 104 |
+
"</style>\n",
|
| 105 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 106 |
+
" <thead>\n",
|
| 107 |
+
" <tr style=\"text-align: right;\">\n",
|
| 108 |
+
" <th></th>\n",
|
| 109 |
+
" <th>dataset</th>\n",
|
| 110 |
+
" <th>srcc</th>\n",
|
| 111 |
+
" <th>plcc</th>\n",
|
| 112 |
+
" </tr>\n",
|
| 113 |
+
" </thead>\n",
|
| 114 |
+
" <tbody>\n",
|
| 115 |
+
" <tr>\n",
|
| 116 |
+
" <th>0</th>\n",
|
| 117 |
+
" <td>finevd</td>\n",
|
| 118 |
+
" <td>0.8596</td>\n",
|
| 119 |
+
" <td>0.8579</td>\n",
|
| 120 |
+
" </tr>\n",
|
| 121 |
+
" <tr>\n",
|
| 122 |
+
" <th>1</th>\n",
|
| 123 |
+
" <td>konvid_1k</td>\n",
|
| 124 |
+
" <td>0.7970</td>\n",
|
| 125 |
+
" <td>0.8171</td>\n",
|
| 126 |
+
" </tr>\n",
|
| 127 |
+
" <tr>\n",
|
| 128 |
+
" <th>2</th>\n",
|
| 129 |
+
" <td>kvq</td>\n",
|
| 130 |
+
" <td>0.8768</td>\n",
|
| 131 |
+
" <td>0.8781</td>\n",
|
| 132 |
+
" </tr>\n",
|
| 133 |
+
" <tr>\n",
|
| 134 |
+
" <th>3</th>\n",
|
| 135 |
+
" <td>live_vqc</td>\n",
|
| 136 |
+
" <td>0.6386</td>\n",
|
| 137 |
+
" <td>0.7192</td>\n",
|
| 138 |
+
" </tr>\n",
|
| 139 |
+
" <tr>\n",
|
| 140 |
+
" <th>4</th>\n",
|
| 141 |
+
" <td>live_yt_gaming</td>\n",
|
| 142 |
+
" <td>0.8554</td>\n",
|
| 143 |
+
" <td>0.8755</td>\n",
|
| 144 |
+
" </tr>\n",
|
| 145 |
+
" <tr>\n",
|
| 146 |
+
" <th>5</th>\n",
|
| 147 |
+
" <td>lsvq</td>\n",
|
| 148 |
+
" <td>0.8753</td>\n",
|
| 149 |
+
" <td>0.8760</td>\n",
|
| 150 |
+
" </tr>\n",
|
| 151 |
+
" <tr>\n",
|
| 152 |
+
" <th>6</th>\n",
|
| 153 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 154 |
+
" <td>0.5431</td>\n",
|
| 155 |
+
" <td>0.6657</td>\n",
|
| 156 |
+
" </tr>\n",
|
| 157 |
+
" <tr>\n",
|
| 158 |
+
" <th>7</th>\n",
|
| 159 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 160 |
+
" <td>0.7884</td>\n",
|
| 161 |
+
" <td>0.8182</td>\n",
|
| 162 |
+
" </tr>\n",
|
| 163 |
+
" <tr>\n",
|
| 164 |
+
" <th>8</th>\n",
|
| 165 |
+
" <td>youtube_ugc</td>\n",
|
| 166 |
+
" <td>0.8325</td>\n",
|
| 167 |
+
" <td>0.8548</td>\n",
|
| 168 |
+
" </tr>\n",
|
| 169 |
+
" </tbody>\n",
|
| 170 |
+
"</table>\n",
|
| 171 |
+
"</div>"
|
| 172 |
+
]
|
| 173 |
+
},
|
| 174 |
+
"execution_count": 21,
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"output_type": "execute_result"
|
| 177 |
+
}
|
| 178 |
+
],
|
| 179 |
+
"execution_count": 21
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"cell_type": "markdown",
|
| 184 |
+
"source": "Cross-dataset evaluation: Train on the \"trainOn\" dataset and test on the \"testOn\" dataset",
|
| 185 |
+
"id": "68dbaf20d93a49ed"
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"metadata": {
|
| 189 |
+
"ExecuteTime": {
|
| 190 |
+
"end_time": "2026-05-19T12:42:02.318434Z",
|
| 191 |
+
"start_time": "2026-05-19T12:42:02.316713Z"
|
| 192 |
+
}
|
| 193 |
+
},
|
| 194 |
+
"cell_type": "code",
|
| 195 |
+
"source": "LOG_ROOT = Path(\"./checkpoints_transfer\")",
|
| 196 |
+
"id": "e8583a08cc496a7c",
|
| 197 |
+
"outputs": [],
|
| 198 |
+
"execution_count": 22
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"metadata": {
|
| 202 |
+
"ExecuteTime": {
|
| 203 |
+
"end_time": "2026-05-19T12:42:02.328167Z",
|
| 204 |
+
"start_time": "2026-05-19T12:42:02.319595Z"
|
| 205 |
+
}
|
| 206 |
+
},
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"source": [
|
| 209 |
+
"filename_pattern = re.compile(r\"^transfer_(?P<trainOn>.+?)_test_on_(?P<testOn>.+?)\\.log$\")\n",
|
| 210 |
+
"result_pattern = re.compile(\n",
|
| 211 |
+
" r\"TEST_ONLY\\s*\\|\\s*\"\n",
|
| 212 |
+
" r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 213 |
+
" r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 214 |
+
" r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
|
| 215 |
+
" r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
|
| 216 |
+
")\n",
|
| 217 |
+
"records = []\n",
|
| 218 |
+
"for log_path in LOG_ROOT.rglob(\"transfer_*_test_on_*.log\"):\n",
|
| 219 |
+
" filename_match = filename_pattern.match(log_path.name)\n",
|
| 220 |
+
" trainOn = filename_match.group(\"trainOn\")\n",
|
| 221 |
+
" testOn = filename_match.group(\"testOn\")\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
|
| 224 |
+
" matches = list(result_pattern.finditer(text))\n",
|
| 225 |
+
" if not matches:\n",
|
| 226 |
+
" continue\n",
|
| 227 |
+
" records.append({\n",
|
| 228 |
+
" \"trainOn\": trainOn,\n",
|
| 229 |
+
" \"testOn\": testOn,\n",
|
| 230 |
+
" \"srcc\": float(matches[-1].group(\"srcc\")),\n",
|
| 231 |
+
" \"plcc\": float(matches[-1].group(\"plcc\")),\n",
|
| 232 |
+
" })\n",
|
| 233 |
+
"transfer_df = (pd.DataFrame(records).sort_values([\"trainOn\", \"testOn\"], ignore_index=True))\n",
|
| 234 |
+
"transfer_df"
|
| 235 |
+
],
|
| 236 |
+
"id": "b196c863f89e0e0d",
|
| 237 |
+
"outputs": [
|
| 238 |
+
{
|
| 239 |
+
"data": {
|
| 240 |
+
"text/plain": [
|
| 241 |
+
" trainOn testOn srcc plcc\n",
|
| 242 |
+
"0 kvq shorts-hdr-dataset_hdr2sdr 0.4761 0.5885\n",
|
| 243 |
+
"1 kvq shorts-hdr-dataset_sdr 0.7548 0.8065\n",
|
| 244 |
+
"2 shorts-hdr-dataset_hdr2sdr kvq 0.5976 0.5693\n",
|
| 245 |
+
"3 shorts-hdr-dataset_hdr2sdr shorts-hdr-dataset_sdr 0.6166 0.6802\n",
|
| 246 |
+
"4 shorts-hdr-dataset_sdr kvq 0.7342 0.7451\n",
|
| 247 |
+
"5 shorts-hdr-dataset_sdr shorts-hdr-dataset_hdr2sdr 0.5447 0.6958"
|
| 248 |
+
],
|
| 249 |
+
"text/html": [
|
| 250 |
+
"<div>\n",
|
| 251 |
+
"<style scoped>\n",
|
| 252 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 253 |
+
" vertical-align: middle;\n",
|
| 254 |
+
" }\n",
|
| 255 |
+
"\n",
|
| 256 |
+
" .dataframe tbody tr th {\n",
|
| 257 |
+
" vertical-align: top;\n",
|
| 258 |
+
" }\n",
|
| 259 |
+
"\n",
|
| 260 |
+
" .dataframe thead th {\n",
|
| 261 |
+
" text-align: right;\n",
|
| 262 |
+
" }\n",
|
| 263 |
+
"</style>\n",
|
| 264 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 265 |
+
" <thead>\n",
|
| 266 |
+
" <tr style=\"text-align: right;\">\n",
|
| 267 |
+
" <th></th>\n",
|
| 268 |
+
" <th>trainOn</th>\n",
|
| 269 |
+
" <th>testOn</th>\n",
|
| 270 |
+
" <th>srcc</th>\n",
|
| 271 |
+
" <th>plcc</th>\n",
|
| 272 |
+
" </tr>\n",
|
| 273 |
+
" </thead>\n",
|
| 274 |
+
" <tbody>\n",
|
| 275 |
+
" <tr>\n",
|
| 276 |
+
" <th>0</th>\n",
|
| 277 |
+
" <td>kvq</td>\n",
|
| 278 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 279 |
+
" <td>0.4761</td>\n",
|
| 280 |
+
" <td>0.5885</td>\n",
|
| 281 |
+
" </tr>\n",
|
| 282 |
+
" <tr>\n",
|
| 283 |
+
" <th>1</th>\n",
|
| 284 |
+
" <td>kvq</td>\n",
|
| 285 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 286 |
+
" <td>0.7548</td>\n",
|
| 287 |
+
" <td>0.8065</td>\n",
|
| 288 |
+
" </tr>\n",
|
| 289 |
+
" <tr>\n",
|
| 290 |
+
" <th>2</th>\n",
|
| 291 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 292 |
+
" <td>kvq</td>\n",
|
| 293 |
+
" <td>0.5976</td>\n",
|
| 294 |
+
" <td>0.5693</td>\n",
|
| 295 |
+
" </tr>\n",
|
| 296 |
+
" <tr>\n",
|
| 297 |
+
" <th>3</th>\n",
|
| 298 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 299 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 300 |
+
" <td>0.6166</td>\n",
|
| 301 |
+
" <td>0.6802</td>\n",
|
| 302 |
+
" </tr>\n",
|
| 303 |
+
" <tr>\n",
|
| 304 |
+
" <th>4</th>\n",
|
| 305 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 306 |
+
" <td>kvq</td>\n",
|
| 307 |
+
" <td>0.7342</td>\n",
|
| 308 |
+
" <td>0.7451</td>\n",
|
| 309 |
+
" </tr>\n",
|
| 310 |
+
" <tr>\n",
|
| 311 |
+
" <th>5</th>\n",
|
| 312 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 313 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 314 |
+
" <td>0.5447</td>\n",
|
| 315 |
+
" <td>0.6958</td>\n",
|
| 316 |
+
" </tr>\n",
|
| 317 |
+
" </tbody>\n",
|
| 318 |
+
"</table>\n",
|
| 319 |
+
"</div>"
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
"execution_count": 23,
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"output_type": "execute_result"
|
| 325 |
+
}
|
| 326 |
+
],
|
| 327 |
+
"execution_count": 23
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"cell_type": "markdown",
|
| 332 |
+
"source": "Cross-dataset evaluation: Train on the \"trainOn\" dataset and fine-tune on the \"finetuneOn\" dataset",
|
| 333 |
+
"id": "e08409c56d2eaea3"
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"metadata": {
|
| 337 |
+
"ExecuteTime": {
|
| 338 |
+
"end_time": "2026-05-19T12:42:02.337383Z",
|
| 339 |
+
"start_time": "2026-05-19T12:42:02.329308Z"
|
| 340 |
+
}
|
| 341 |
+
},
|
| 342 |
+
"cell_type": "code",
|
| 343 |
+
"source": [
|
"filename_pattern = re.compile(r\"^transfer_(?P<trainOn>.+?)_finetune_on_(?P<finetuneOn>.+?)\\.log$\")\n",
"finetune_result_pattern = re.compile(\n",
"    r\"FINETUNE TEST\\s*\\|\\s*\"\n",
"    r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
"    r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
"    r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
"    r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
")\n",
"records = []\n",
"for log_path in LOG_ROOT.rglob(\"transfer_*_finetune_on_*.log\"):\n",
"    filename_match = filename_pattern.match(log_path.name)\n",
"    trainOn = filename_match.group(\"trainOn\")\n",
"    finetuneOn = filename_match.group(\"finetuneOn\")\n",
"\n",
"    text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
"    matches = list(finetune_result_pattern.finditer(text))\n",
"    if not matches:\n",
"        continue\n",
"    records.append({\n",
"        \"trainOn\": trainOn,\n",
"        \"finetuneOn\": finetuneOn,\n",
"        \"srcc\": float(matches[-1].group(\"srcc\")),\n",
"        \"plcc\": float(matches[-1].group(\"plcc\")),\n",
"    })\n",
"finetune_df = (pd.DataFrame(records).sort_values([\"trainOn\", \"finetuneOn\"], ignore_index=True))\n",
"finetune_df"
|
| 370 |
+
],
|
| 371 |
+
"id": "bfe4a5008377348d",
|
| 372 |
+
"outputs": [
|
| 373 |
+
{
|
| 374 |
+
"data": {
|
| 375 |
+
"text/plain": [
|
"                      trainOn                  finetuneOn    srcc    plcc\n",
"0                         kvq  shorts-hdr-dataset_hdr2sdr  0.6412  0.7225\n",
"1                         kvq      shorts-hdr-dataset_sdr  0.8291  0.8683\n",
"2  shorts-hdr-dataset_hdr2sdr                         kvq  0.8743  0.8792\n",
"3  shorts-hdr-dataset_hdr2sdr      shorts-hdr-dataset_sdr  0.8183  0.8561\n",
"4      shorts-hdr-dataset_sdr                         kvq  0.8862  0.8883\n",
"5      shorts-hdr-dataset_sdr  shorts-hdr-dataset_hdr2sdr  0.6589  0.8013"
|
| 383 |
+
],
|
| 384 |
+
"text/html": [
|
| 385 |
+
"<div>\n",
|
| 386 |
+
"<style scoped>\n",
|
| 387 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 388 |
+
" vertical-align: middle;\n",
|
| 389 |
+
" }\n",
|
| 390 |
+
"\n",
|
| 391 |
+
" .dataframe tbody tr th {\n",
|
| 392 |
+
" vertical-align: top;\n",
|
| 393 |
+
" }\n",
|
| 394 |
+
"\n",
|
| 395 |
+
" .dataframe thead th {\n",
|
| 396 |
+
" text-align: right;\n",
|
| 397 |
+
" }\n",
|
| 398 |
+
"</style>\n",
|
| 399 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 400 |
+
" <thead>\n",
|
| 401 |
+
" <tr style=\"text-align: right;\">\n",
|
| 402 |
+
" <th></th>\n",
|
| 403 |
+
" <th>trainOn</th>\n",
|
| 404 |
+
" <th>finetuneOn</th>\n",
|
| 405 |
+
" <th>srcc</th>\n",
|
| 406 |
+
" <th>plcc</th>\n",
|
| 407 |
+
" </tr>\n",
|
| 408 |
+
" </thead>\n",
|
| 409 |
+
" <tbody>\n",
|
| 410 |
+
" <tr>\n",
|
| 411 |
+
" <th>0</th>\n",
|
| 412 |
+
" <td>kvq</td>\n",
|
| 413 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 414 |
+
" <td>0.6412</td>\n",
|
| 415 |
+
" <td>0.7225</td>\n",
|
| 416 |
+
" </tr>\n",
|
| 417 |
+
" <tr>\n",
|
| 418 |
+
" <th>1</th>\n",
|
| 419 |
+
" <td>kvq</td>\n",
|
| 420 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 421 |
+
" <td>0.8291</td>\n",
|
| 422 |
+
" <td>0.8683</td>\n",
|
| 423 |
+
" </tr>\n",
|
| 424 |
+
" <tr>\n",
|
| 425 |
+
" <th>2</th>\n",
|
| 426 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 427 |
+
" <td>kvq</td>\n",
|
| 428 |
+
" <td>0.8743</td>\n",
|
| 429 |
+
" <td>0.8792</td>\n",
|
| 430 |
+
" </tr>\n",
|
| 431 |
+
" <tr>\n",
|
| 432 |
+
" <th>3</th>\n",
|
| 433 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 434 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 435 |
+
" <td>0.8183</td>\n",
|
| 436 |
+
" <td>0.8561</td>\n",
|
| 437 |
+
" </tr>\n",
|
| 438 |
+
" <tr>\n",
|
| 439 |
+
" <th>4</th>\n",
|
| 440 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 441 |
+
" <td>kvq</td>\n",
|
| 442 |
+
" <td>0.8862</td>\n",
|
| 443 |
+
" <td>0.8883</td>\n",
|
| 444 |
+
" </tr>\n",
|
| 445 |
+
" <tr>\n",
|
| 446 |
+
" <th>5</th>\n",
|
| 447 |
+
" <td>shorts-hdr-dataset_sdr</td>\n",
|
| 448 |
+
" <td>shorts-hdr-dataset_hdr2sdr</td>\n",
|
| 449 |
+
" <td>0.6589</td>\n",
|
| 450 |
+
" <td>0.8013</td>\n",
|
| 451 |
+
" </tr>\n",
|
| 452 |
+
" </tbody>\n",
|
| 453 |
+
"</table>\n",
|
| 454 |
+
"</div>"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
"execution_count": 24,
|
| 458 |
+
"metadata": {},
|
| 459 |
+
"output_type": "execute_result"
|
| 460 |
+
}
|
| 461 |
+
],
|
| 462 |
+
"execution_count": 24
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"metadata": {},
|
| 466 |
+
"cell_type": "markdown",
|
| 467 |
+
"source": "Complexity_test",
|
| 468 |
+
"id": "544f3bed13c8f3d9"
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"metadata": {
|
| 472 |
+
"ExecuteTime": {
|
| 473 |
+
"end_time": "2026-05-19T12:42:02.339580Z",
|
| 474 |
+
"start_time": "2026-05-19T12:42:02.337969Z"
|
| 475 |
+
}
|
| 476 |
+
},
|
| 477 |
+
"cell_type": "code",
|
| 478 |
+
"source": "complexity_csv_path = \"../test_videos/complexity_test/\"",
|
| 479 |
+
"id": "9d718a191ee6bdf2",
|
| 480 |
+
"outputs": [],
|
| 481 |
+
"execution_count": 25
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"metadata": {
|
| 485 |
+
"ExecuteTime": {
|
| 486 |
+
"end_time": "2026-05-19T12:42:02.347488Z",
|
| 487 |
+
"start_time": "2026-05-19T12:42:02.340191Z"
|
| 488 |
+
}
|
| 489 |
+
},
|
| 490 |
+
"cell_type": "code",
|
| 491 |
+
"source": [
|
| 492 |
+
"resolution_df1 = pd.read_csv(f\"{complexity_csv_path}complexity_resolution_KVQ_0546.csv\").assign(video=\"KVQ_0546\")\n",
|
| 493 |
+
"resolution_df2 = pd.read_csv(f\"{complexity_csv_path}complexity_resolution_SDR_Animal_5ngj.csv\").assign(video=\"SDR_Animal_5ngj\")\n",
|
| 494 |
+
"resolution_df = pd.concat([resolution_df1, resolution_df2], ignore_index=True)\n",
|
| 495 |
+
"resolution_df"
|
| 496 |
+
],
|
| 497 |
+
"id": "5c99ce5680f78c34",
|
| 498 |
+
"outputs": [
|
| 499 |
+
{
|
| 500 |
+
"data": {
|
| 501 |
+
"text/plain": [
|
"       model   540P   720P  1080P  2160P    MOS            video\n",
"0   FAST-VQA  0.638  0.745  1.020  2.680  4.017         KVQ_0546\n",
"1  FasterVQA  0.428  0.480  0.573  1.077  4.356         KVQ_0546\n",
"2      DOVER  1.018  1.163  1.511  3.662  4.223         KVQ_0546\n",
"3        Our  0.364  0.485  0.753  2.437  3.650         KVQ_0546\n",
"4   FAST-VQA  0.599  0.673  0.909  2.217  3.319  SDR_Animal_5ngj\n",
"5  FasterVQA  0.489  0.547  0.696  1.343  3.556  SDR_Animal_5ngj\n",
"6      DOVER  0.920  1.022  1.293  2.783  3.814  SDR_Animal_5ngj\n",
"7        Our  0.313  0.405  0.697  2.137  3.878  SDR_Animal_5ngj"
|
| 511 |
+
],
|
| 512 |
+
"text/html": [
|
| 513 |
+
"<div>\n",
|
| 514 |
+
"<style scoped>\n",
|
| 515 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 516 |
+
" vertical-align: middle;\n",
|
| 517 |
+
" }\n",
|
| 518 |
+
"\n",
|
| 519 |
+
" .dataframe tbody tr th {\n",
|
| 520 |
+
" vertical-align: top;\n",
|
| 521 |
+
" }\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" .dataframe thead th {\n",
|
| 524 |
+
" text-align: right;\n",
|
| 525 |
+
" }\n",
|
| 526 |
+
"</style>\n",
|
| 527 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 528 |
+
" <thead>\n",
|
| 529 |
+
" <tr style=\"text-align: right;\">\n",
|
| 530 |
+
" <th></th>\n",
|
| 531 |
+
" <th>model</th>\n",
|
| 532 |
+
" <th>540P</th>\n",
|
| 533 |
+
" <th>720P</th>\n",
|
| 534 |
+
" <th>1080P</th>\n",
|
| 535 |
+
" <th>2160P</th>\n",
|
| 536 |
+
" <th>MOS</th>\n",
|
| 537 |
+
" <th>video</th>\n",
|
| 538 |
+
" </tr>\n",
|
| 539 |
+
" </thead>\n",
|
| 540 |
+
" <tbody>\n",
|
| 541 |
+
" <tr>\n",
|
| 542 |
+
" <th>0</th>\n",
|
| 543 |
+
" <td>FAST-VQA</td>\n",
|
| 544 |
+
" <td>0.638</td>\n",
|
| 545 |
+
" <td>0.745</td>\n",
|
| 546 |
+
" <td>1.020</td>\n",
|
| 547 |
+
" <td>2.680</td>\n",
|
| 548 |
+
" <td>4.017</td>\n",
|
| 549 |
+
" <td>KVQ_0546</td>\n",
|
| 550 |
+
" </tr>\n",
|
| 551 |
+
" <tr>\n",
|
| 552 |
+
" <th>1</th>\n",
|
| 553 |
+
" <td>FasterVQA</td>\n",
|
| 554 |
+
" <td>0.428</td>\n",
|
| 555 |
+
" <td>0.480</td>\n",
|
| 556 |
+
" <td>0.573</td>\n",
|
| 557 |
+
" <td>1.077</td>\n",
|
| 558 |
+
" <td>4.356</td>\n",
|
| 559 |
+
" <td>KVQ_0546</td>\n",
|
| 560 |
+
" </tr>\n",
|
| 561 |
+
" <tr>\n",
|
| 562 |
+
" <th>2</th>\n",
|
| 563 |
+
" <td>DOVER</td>\n",
|
| 564 |
+
" <td>1.018</td>\n",
|
| 565 |
+
" <td>1.163</td>\n",
|
| 566 |
+
" <td>1.511</td>\n",
|
| 567 |
+
" <td>3.662</td>\n",
|
| 568 |
+
" <td>4.223</td>\n",
|
| 569 |
+
" <td>KVQ_0546</td>\n",
|
| 570 |
+
" </tr>\n",
|
| 571 |
+
" <tr>\n",
|
| 572 |
+
" <th>3</th>\n",
|
| 573 |
+
" <td>Our</td>\n",
|
| 574 |
+
" <td>0.364</td>\n",
|
| 575 |
+
" <td>0.485</td>\n",
|
| 576 |
+
" <td>0.753</td>\n",
|
| 577 |
+
" <td>2.437</td>\n",
|
| 578 |
+
" <td>3.650</td>\n",
|
| 579 |
+
" <td>KVQ_0546</td>\n",
|
| 580 |
+
" </tr>\n",
|
| 581 |
+
" <tr>\n",
|
| 582 |
+
" <th>4</th>\n",
|
| 583 |
+
" <td>FAST-VQA</td>\n",
|
| 584 |
+
" <td>0.599</td>\n",
|
| 585 |
+
" <td>0.673</td>\n",
|
| 586 |
+
" <td>0.909</td>\n",
|
| 587 |
+
" <td>2.217</td>\n",
|
| 588 |
+
" <td>3.319</td>\n",
|
| 589 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 590 |
+
" </tr>\n",
|
| 591 |
+
" <tr>\n",
|
| 592 |
+
" <th>5</th>\n",
|
| 593 |
+
" <td>FasterVQA</td>\n",
|
| 594 |
+
" <td>0.489</td>\n",
|
| 595 |
+
" <td>0.547</td>\n",
|
| 596 |
+
" <td>0.696</td>\n",
|
| 597 |
+
" <td>1.343</td>\n",
|
| 598 |
+
" <td>3.556</td>\n",
|
| 599 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 600 |
+
" </tr>\n",
|
| 601 |
+
" <tr>\n",
|
| 602 |
+
" <th>6</th>\n",
|
| 603 |
+
" <td>DOVER</td>\n",
|
| 604 |
+
" <td>0.920</td>\n",
|
| 605 |
+
" <td>1.022</td>\n",
|
| 606 |
+
" <td>1.293</td>\n",
|
| 607 |
+
" <td>2.783</td>\n",
|
| 608 |
+
" <td>3.814</td>\n",
|
| 609 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 610 |
+
" </tr>\n",
|
| 611 |
+
" <tr>\n",
|
| 612 |
+
" <th>7</th>\n",
|
| 613 |
+
" <td>Our</td>\n",
|
| 614 |
+
" <td>0.313</td>\n",
|
| 615 |
+
" <td>0.405</td>\n",
|
| 616 |
+
" <td>0.697</td>\n",
|
| 617 |
+
" <td>2.137</td>\n",
|
| 618 |
+
" <td>3.878</td>\n",
|
| 619 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 620 |
+
" </tr>\n",
|
| 621 |
+
" </tbody>\n",
|
| 622 |
+
"</table>\n",
|
| 623 |
+
"</div>"
|
| 624 |
+
]
|
| 625 |
+
},
|
| 626 |
+
"execution_count": 26,
|
| 627 |
+
"metadata": {},
|
| 628 |
+
"output_type": "execute_result"
|
| 629 |
+
}
|
| 630 |
+
],
|
| 631 |
+
"execution_count": 26
|
| 632 |
+
},
|
| 633 |
+
{
|
| 634 |
+
"metadata": {
|
| 635 |
+
"ExecuteTime": {
|
| 636 |
+
"end_time": "2026-05-19T12:42:02.353757Z",
|
| 637 |
+
"start_time": "2026-05-19T12:42:02.348243Z"
|
| 638 |
+
}
|
| 639 |
+
},
|
| 640 |
+
"cell_type": "code",
|
| 641 |
+
"source": [
|
| 642 |
+
"complexity_df = pd.read_csv(f\"{complexity_csv_path}complexity_test.csv\")\n",
|
| 643 |
+
"complexity_df"
|
| 644 |
+
],
|
| 645 |
+
"id": "f639ed0f66a50c57",
|
| 646 |
+
"outputs": [
|
| 647 |
+
{
|
| 648 |
+
"data": {
|
| 649 |
+
"text/plain": [
|
"        model              video   Time  Predicted_MOS  Scaled_MOS    MOS\n",
"0    FAST-VQA  SDR_Gameplay_wcq2  0.822       0.248790       1.995  4.156\n",
"1    FAST-VQA  SDR_Gameplay_s0pc  0.820       0.550190       3.201  4.264\n",
"2    FAST-VQA  SDR_Gameplay_z5fz  0.823       0.538800       3.155  4.221\n",
"3    FAST-VQA    SDR_Animal_5ngj  0.846       0.579850       3.319  4.308\n",
"4    FAST-VQA               0546  0.895       0.754290       4.017  2.389\n",
"5   FasterVQA  SDR_Gameplay_wcq2  0.502       0.299890       2.200  4.156\n",
"6   FasterVQA  SDR_Gameplay_s0pc  0.524       0.706330       3.825  4.264\n",
"7   FasterVQA  SDR_Gameplay_z5fz  0.499       0.400900       2.604  4.221\n",
"8   FasterVQA    SDR_Animal_5ngj  0.500       0.638990       3.556  4.308\n",
"9   FasterVQA               0546  0.484       0.839070       4.356  2.389\n",
"10      DOVER  SDR_Gameplay_wcq2  1.407       0.347305       2.389  4.156\n",
"11      DOVER  SDR_Gameplay_s0pc  1.345       0.582339       3.329  4.264\n",
"12      DOVER  SDR_Gameplay_z5fz  1.341       0.524504       3.098  4.221\n",
"13      DOVER    SDR_Animal_5ngj  1.326       0.703516       3.814  4.308\n",
"14      DOVER               0546  1.489       0.805786       4.223  2.389\n",
"15        Our  SDR_Gameplay_wcq2  0.594      59.864200       3.395  4.156\n",
"16        Our  SDR_Gameplay_s0pc  0.605      61.772900       3.471  4.264\n",
"17        Our  SDR_Gameplay_z5fz  0.602      61.077000       3.443  4.221\n",
"18        Our    SDR_Animal_5ngj  0.598      71.957800       3.878  4.308\n",
"19        Our               0546  0.734      66.257200       3.650  2.389"
|
| 671 |
+
],
|
| 672 |
+
"text/html": [
|
| 673 |
+
"<div>\n",
|
| 674 |
+
"<style scoped>\n",
|
| 675 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 676 |
+
" vertical-align: middle;\n",
|
| 677 |
+
" }\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" .dataframe tbody tr th {\n",
|
| 680 |
+
" vertical-align: top;\n",
|
| 681 |
+
" }\n",
|
| 682 |
+
"\n",
|
| 683 |
+
" .dataframe thead th {\n",
|
| 684 |
+
" text-align: right;\n",
|
| 685 |
+
" }\n",
|
| 686 |
+
"</style>\n",
|
| 687 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 688 |
+
" <thead>\n",
|
| 689 |
+
" <tr style=\"text-align: right;\">\n",
|
| 690 |
+
" <th></th>\n",
|
| 691 |
+
" <th>model</th>\n",
|
| 692 |
+
" <th>video</th>\n",
|
| 693 |
+
" <th>Time</th>\n",
|
| 694 |
+
" <th>Predicted_MOS</th>\n",
|
| 695 |
+
" <th>Scaled_MOS</th>\n",
|
| 696 |
+
" <th>MOS</th>\n",
|
| 697 |
+
" </tr>\n",
|
| 698 |
+
" </thead>\n",
|
| 699 |
+
" <tbody>\n",
|
| 700 |
+
" <tr>\n",
|
| 701 |
+
" <th>0</th>\n",
|
| 702 |
+
" <td>FAST-VQA</td>\n",
|
| 703 |
+
" <td>SDR_Gameplay_wcq2</td>\n",
|
| 704 |
+
" <td>0.822</td>\n",
|
| 705 |
+
" <td>0.248790</td>\n",
|
| 706 |
+
" <td>1.995</td>\n",
|
| 707 |
+
" <td>4.156</td>\n",
|
| 708 |
+
" </tr>\n",
|
| 709 |
+
" <tr>\n",
|
| 710 |
+
" <th>1</th>\n",
|
| 711 |
+
" <td>FAST-VQA</td>\n",
|
| 712 |
+
" <td>SDR_Gameplay_s0pc</td>\n",
|
| 713 |
+
" <td>0.820</td>\n",
|
| 714 |
+
" <td>0.550190</td>\n",
|
| 715 |
+
" <td>3.201</td>\n",
|
| 716 |
+
" <td>4.264</td>\n",
|
| 717 |
+
" </tr>\n",
|
| 718 |
+
" <tr>\n",
|
| 719 |
+
" <th>2</th>\n",
|
| 720 |
+
" <td>FAST-VQA</td>\n",
|
| 721 |
+
" <td>SDR_Gameplay_z5fz</td>\n",
|
| 722 |
+
" <td>0.823</td>\n",
|
| 723 |
+
" <td>0.538800</td>\n",
|
| 724 |
+
" <td>3.155</td>\n",
|
| 725 |
+
" <td>4.221</td>\n",
|
| 726 |
+
" </tr>\n",
|
| 727 |
+
" <tr>\n",
|
| 728 |
+
" <th>3</th>\n",
|
| 729 |
+
" <td>FAST-VQA</td>\n",
|
| 730 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 731 |
+
" <td>0.846</td>\n",
|
| 732 |
+
" <td>0.579850</td>\n",
|
| 733 |
+
" <td>3.319</td>\n",
|
| 734 |
+
" <td>4.308</td>\n",
|
| 735 |
+
" </tr>\n",
|
| 736 |
+
" <tr>\n",
|
| 737 |
+
" <th>4</th>\n",
|
| 738 |
+
" <td>FAST-VQA</td>\n",
|
| 739 |
+
" <td>0546</td>\n",
|
| 740 |
+
" <td>0.895</td>\n",
|
| 741 |
+
" <td>0.754290</td>\n",
|
| 742 |
+
" <td>4.017</td>\n",
|
| 743 |
+
" <td>2.389</td>\n",
|
| 744 |
+
" </tr>\n",
|
| 745 |
+
" <tr>\n",
|
| 746 |
+
" <th>5</th>\n",
|
| 747 |
+
" <td>FasterVQA</td>\n",
|
| 748 |
+
" <td>SDR_Gameplay_wcq2</td>\n",
|
| 749 |
+
" <td>0.502</td>\n",
|
| 750 |
+
" <td>0.299890</td>\n",
|
| 751 |
+
" <td>2.200</td>\n",
|
| 752 |
+
" <td>4.156</td>\n",
|
| 753 |
+
" </tr>\n",
|
| 754 |
+
" <tr>\n",
|
| 755 |
+
" <th>6</th>\n",
|
| 756 |
+
" <td>FasterVQA</td>\n",
|
| 757 |
+
" <td>SDR_Gameplay_s0pc</td>\n",
|
| 758 |
+
" <td>0.524</td>\n",
|
| 759 |
+
" <td>0.706330</td>\n",
|
| 760 |
+
" <td>3.825</td>\n",
|
| 761 |
+
" <td>4.264</td>\n",
|
| 762 |
+
" </tr>\n",
|
| 763 |
+
" <tr>\n",
|
| 764 |
+
" <th>7</th>\n",
|
| 765 |
+
" <td>FasterVQA</td>\n",
|
| 766 |
+
" <td>SDR_Gameplay_z5fz</td>\n",
|
| 767 |
+
" <td>0.499</td>\n",
|
| 768 |
+
" <td>0.400900</td>\n",
|
| 769 |
+
" <td>2.604</td>\n",
|
| 770 |
+
" <td>4.221</td>\n",
|
| 771 |
+
" </tr>\n",
|
| 772 |
+
" <tr>\n",
|
| 773 |
+
" <th>8</th>\n",
|
| 774 |
+
" <td>FasterVQA</td>\n",
|
| 775 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 776 |
+
" <td>0.500</td>\n",
|
| 777 |
+
" <td>0.638990</td>\n",
|
| 778 |
+
" <td>3.556</td>\n",
|
| 779 |
+
" <td>4.308</td>\n",
|
| 780 |
+
" </tr>\n",
|
| 781 |
+
" <tr>\n",
|
| 782 |
+
" <th>9</th>\n",
|
| 783 |
+
" <td>FasterVQA</td>\n",
|
| 784 |
+
" <td>0546</td>\n",
|
| 785 |
+
" <td>0.484</td>\n",
|
| 786 |
+
" <td>0.839070</td>\n",
|
| 787 |
+
" <td>4.356</td>\n",
|
| 788 |
+
" <td>2.389</td>\n",
|
| 789 |
+
" </tr>\n",
|
| 790 |
+
" <tr>\n",
|
| 791 |
+
" <th>10</th>\n",
|
| 792 |
+
" <td>DOVER</td>\n",
|
| 793 |
+
" <td>SDR_Gameplay_wcq2</td>\n",
|
| 794 |
+
" <td>1.407</td>\n",
|
| 795 |
+
" <td>0.347305</td>\n",
|
| 796 |
+
" <td>2.389</td>\n",
|
| 797 |
+
" <td>4.156</td>\n",
|
| 798 |
+
" </tr>\n",
|
| 799 |
+
" <tr>\n",
|
| 800 |
+
" <th>11</th>\n",
|
| 801 |
+
" <td>DOVER</td>\n",
|
| 802 |
+
" <td>SDR_Gameplay_s0pc</td>\n",
|
| 803 |
+
" <td>1.345</td>\n",
|
| 804 |
+
" <td>0.582339</td>\n",
|
| 805 |
+
" <td>3.329</td>\n",
|
| 806 |
+
" <td>4.264</td>\n",
|
| 807 |
+
" </tr>\n",
|
| 808 |
+
" <tr>\n",
|
| 809 |
+
" <th>12</th>\n",
|
| 810 |
+
" <td>DOVER</td>\n",
|
| 811 |
+
" <td>SDR_Gameplay_z5fz</td>\n",
|
| 812 |
+
" <td>1.341</td>\n",
|
| 813 |
+
" <td>0.524504</td>\n",
|
| 814 |
+
" <td>3.098</td>\n",
|
| 815 |
+
" <td>4.221</td>\n",
|
| 816 |
+
" </tr>\n",
|
| 817 |
+
" <tr>\n",
|
| 818 |
+
" <th>13</th>\n",
|
| 819 |
+
" <td>DOVER</td>\n",
|
| 820 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 821 |
+
" <td>1.326</td>\n",
|
| 822 |
+
" <td>0.703516</td>\n",
|
| 823 |
+
" <td>3.814</td>\n",
|
| 824 |
+
" <td>4.308</td>\n",
|
| 825 |
+
" </tr>\n",
|
| 826 |
+
" <tr>\n",
|
| 827 |
+
" <th>14</th>\n",
|
| 828 |
+
" <td>DOVER</td>\n",
|
| 829 |
+
" <td>0546</td>\n",
|
| 830 |
+
" <td>1.489</td>\n",
|
| 831 |
+
" <td>0.805786</td>\n",
|
| 832 |
+
" <td>4.223</td>\n",
|
| 833 |
+
" <td>2.389</td>\n",
|
| 834 |
+
" </tr>\n",
|
| 835 |
+
" <tr>\n",
|
| 836 |
+
" <th>15</th>\n",
|
| 837 |
+
" <td>Our</td>\n",
|
| 838 |
+
" <td>SDR_Gameplay_wcq2</td>\n",
|
| 839 |
+
" <td>0.594</td>\n",
|
| 840 |
+
" <td>59.864200</td>\n",
|
| 841 |
+
" <td>3.395</td>\n",
|
| 842 |
+
" <td>4.156</td>\n",
|
| 843 |
+
" </tr>\n",
|
| 844 |
+
" <tr>\n",
|
| 845 |
+
" <th>16</th>\n",
|
| 846 |
+
" <td>Our</td>\n",
|
| 847 |
+
" <td>SDR_Gameplay_s0pc</td>\n",
|
| 848 |
+
" <td>0.605</td>\n",
|
| 849 |
+
" <td>61.772900</td>\n",
|
| 850 |
+
" <td>3.471</td>\n",
|
| 851 |
+
" <td>4.264</td>\n",
|
| 852 |
+
" </tr>\n",
|
| 853 |
+
" <tr>\n",
|
| 854 |
+
" <th>17</th>\n",
|
| 855 |
+
" <td>Our</td>\n",
|
| 856 |
+
" <td>SDR_Gameplay_z5fz</td>\n",
|
| 857 |
+
" <td>0.602</td>\n",
|
| 858 |
+
" <td>61.077000</td>\n",
|
| 859 |
+
" <td>3.443</td>\n",
|
| 860 |
+
" <td>4.221</td>\n",
|
| 861 |
+
" </tr>\n",
|
| 862 |
+
" <tr>\n",
|
| 863 |
+
" <th>18</th>\n",
|
| 864 |
+
" <td>Our</td>\n",
|
| 865 |
+
" <td>SDR_Animal_5ngj</td>\n",
|
| 866 |
+
" <td>0.598</td>\n",
|
| 867 |
+
" <td>71.957800</td>\n",
|
| 868 |
+
" <td>3.878</td>\n",
|
| 869 |
+
" <td>4.308</td>\n",
|
| 870 |
+
" </tr>\n",
|
| 871 |
+
" <tr>\n",
|
| 872 |
+
" <th>19</th>\n",
|
| 873 |
+
" <td>Our</td>\n",
|
| 874 |
+
" <td>0546</td>\n",
|
| 875 |
+
" <td>0.734</td>\n",
|
| 876 |
+
" <td>66.257200</td>\n",
|
| 877 |
+
" <td>3.650</td>\n",
|
| 878 |
+
" <td>2.389</td>\n",
|
| 879 |
+
" </tr>\n",
|
| 880 |
+
" </tbody>\n",
|
| 881 |
+
"</table>\n",
|
| 882 |
+
"</div>"
|
| 883 |
+
]
|
| 884 |
+
},
|
| 885 |
+
"execution_count": 27,
|
| 886 |
+
"metadata": {},
|
| 887 |
+
"output_type": "execute_result"
|
| 888 |
+
}
|
| 889 |
+
],
|
| 890 |
+
"execution_count": 27
|
| 891 |
+
}
|
| 892 |
+
],
|
| 893 |
+
"metadata": {
|
| 894 |
+
"kernelspec": {
|
| 895 |
+
"display_name": "Python 3",
|
| 896 |
+
"language": "python",
|
| 897 |
+
"name": "python3"
|
| 898 |
+
},
|
| 899 |
+
"language_info": {
|
| 900 |
+
"codemirror_mode": {
|
| 901 |
+
"name": "ipython",
|
| 902 |
+
"version": 2
|
| 903 |
+
},
|
| 904 |
+
"file_extension": ".py",
|
| 905 |
+
"mimetype": "text/x-python",
|
| 906 |
+
"name": "python",
|
| 907 |
+
"nbconvert_exporter": "python",
|
| 908 |
+
"pygments_lexer": "ipython2",
|
| 909 |
+
"version": "2.7.6"
|
| 910 |
+
}
|
| 911 |
+
},
|
| 912 |
+
"nbformat": 4,
|
| 913 |
+
"nbformat_minor": 5
|
| 914 |
+
}
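The fine-tuning cell in the notebook above collects one row per `transfer_*_finetune_on_*.log` file by taking the last `FINETUNE TEST | ...` line in each log. The sketch below isolates that parsing step so it can be tried without any log files on disk. The function name `parse_finetune_log` and the sample log text are illustrative assumptions, not part of the repository; the `loss` and `rmse` values in the sample are made up, and a guard for non-matching filenames is added that the notebook cell does not have.

```python
import re

# Same pattern shapes as the notebook cell.
FILENAME = re.compile(r"^transfer_(?P<trainOn>.+?)_finetune_on_(?P<finetuneOn>.+?)\.log$")
FINETUNE_RESULT = re.compile(
    r"FINETUNE TEST\s*\|\s*"
    r"loss=(?P<loss>[-+]?\d*\.?\d+)\s+"
    r"plcc=(?P<plcc>[-+]?\d*\.?\d+)\s+"
    r"srcc=(?P<srcc>[-+]?\d*\.?\d+)\s+"
    r"rmse=(?P<rmse>[-+]?\d*\.?\d+)"
)


def parse_finetune_log(filename, text):
    """Return the last FINETUNE TEST result in `text`, or None if nothing matches."""
    name_match = FILENAME.match(filename)
    if name_match is None:  # hypothetical guard: skip unexpected filenames
        return None
    results = list(FINETUNE_RESULT.finditer(text))
    if not results:
        return None
    last = results[-1]  # the final reported test result wins
    return {
        "trainOn": name_match.group("trainOn"),
        "finetuneOn": name_match.group("finetuneOn"),
        "srcc": float(last.group("srcc")),
        "plcc": float(last.group("plcc")),
    }


# Made-up log content: two result lines, only the last one should be kept.
sample_text = (
    "FINETUNE TEST | loss=0.4100 plcc=0.8000 srcc=0.7900 rmse=0.5000\n"
    "FINETUNE TEST | loss=0.3521 plcc=0.8683 srcc=0.8291 rmse=0.4417\n"
)
record = parse_finetune_log("transfer_kvq_finetune_on_shorts-hdr-dataset_sdr.log", sample_text)
print(record)
```

Because both filename groups are lazy (`.+?`), the split happens at the first `_finetune_on_` in the name, which is what lets dataset names that themselves contain underscores or hyphens come through intact.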
|
src/demo_test.py
ADDED
|
@@ -0,0 +1,264 @@
| 1 |
+
import os
|
| 2 |
+
os.environ.setdefault("DECORD_DUPLICATE_WARNING_THRESHOLD", "1.0")
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.amp import autocast
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from thop import profile
|
| 14 |
+
from thop import clever_format
|
| 15 |
+
|
| 16 |
+
from train import VQADataset
|
| 17 |
+
from model.qd_model import QD_MODEL
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_checkpoint(ckpt_path, device):
|
| 21 |
+
ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=True)
|
| 22 |
+
if isinstance(ckpt, dict) and "model" in ckpt:
|
| 23 |
+
return {
|
| 24 |
+
"state_dict": ckpt["model"],
|
| 25 |
+
"train_mos_mean": ckpt.get("mos_mean"),
|
| 26 |
+
"train_mos_std": ckpt.get("mos_std"),
|
| 27 |
+
"train_args": ckpt.get("args", {}),
|
| 28 |
+
"is_full_checkpoint": True,
|
| 29 |
+
}
|
| 30 |
+
if isinstance(ckpt, dict):
|
| 31 |
+
return {
|
| 32 |
+
"state_dict": ckpt,
|
| 33 |
+
"train_mos_mean": None,
|
| 34 |
+
"train_mos_std": None,
|
| 35 |
+
"train_args": {},
|
| 36 |
+
"is_full_checkpoint": False,
|
| 37 |
+
}
|
| 38 |
+
raise TypeError(f"Unsupported checkpoint type: {type(ckpt)!r}")
|
| 39 |
+
|
| 40 |
+
class ForwardWrapper(nn.Module):
|
| 41 |
+
def __init__(self, model):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.model = model
|
| 44 |
+
|
| 45 |
+
def forward(self, rgb, w_art, w_str):
|
| 46 |
+
yhat, _aux = self.model(rgb, w_art, w_str)
|
| 47 |
+
return yhat
|
| 48 |
+
|
| 49 |
+
def count_parameters(model):
|
| 50 |
+
total = sum(p.numel() for p in model.parameters())
|
| 51 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 52 |
+
return total, trainable
|
| 53 |
+
|
| 54 |
+
def build_dataset(args, mos_mean, mos_std):
|
| 55 |
+
rows = [(args.video_id, float(args.dummy_mos))]
|
| 56 |
+
dataset = VQADataset(
|
| 57 |
+
rows,
|
| 58 |
+
args.db_path,
|
| 59 |
+
clip_len=args.clip_len,
|
| 60 |
+
size=args.resize,
|
| 61 |
+
win=args.win,
|
| 62 |
+
win_step=args.win_step,
|
| 63 |
+
mos_mean=float(mos_mean),
|
| 64 |
+
mos_std=float(mos_std),
|
| 65 |
+
)
|
| 66 |
+
return dataset
|
| 67 |
+
|
| 68 |
+
def prepare_single_sample(dataset, device):
|
| 69 |
+
rgb, w_art, w_str, y, vid = dataset[0]
|
| 70 |
+
|
| 71 |
+
rgb = rgb.unsqueeze(0).to(device, non_blocking=True)
|
| 72 |
+
w_art = w_art.unsqueeze(0).to(device, non_blocking=True)
|
| 73 |
+
w_str = w_str.unsqueeze(0).to(device, non_blocking=True)
|
| 74 |
+
y = y.unsqueeze(0).to(device, non_blocking=True).float()
|
| 75 |
+
|
| 76 |
+
if isinstance(vid, (list, tuple)):
|
| 77 |
+
video_id = vid[0]
|
| 78 |
+
else:
|
| 79 |
+
video_id = vid
|
| 80 |
+
return rgb, w_art, w_str, y, video_id
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@torch.no_grad()
|
| 84 |
+
def predict_once(
|
| 85 |
+
model,
|
| 86 |
+
rgb,
|
| 87 |
+
w_art,
|
| 88 |
+
w_str,
|
| 89 |
+
*,
|
| 90 |
+
device,
|
| 91 |
+
amp,
|
| 92 |
+
train_mos_mean,
|
| 93 |
+
train_mos_std,
|
| 94 |
+
):
|
| 95 |
+
model.eval()
|
| 96 |
+
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
|
| 97 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 98 |
+
yhat, _aux = model(rgb, w_art, w_str)
|
| 99 |
+
|
| 100 |
+
pred_score = yhat.detach().float().cpu() * float(train_mos_std) + float(train_mos_mean)
|
| 101 |
+
return float(pred_score.squeeze().item())
|
| 102 |
+
|
| 103 |
+
@torch.no_grad()
|
| 104 |
+
def profile_with_thop(model, rgb, w_art, w_str):
|
| 105 |
+
macs, params = profile(model, inputs=(rgb, w_art, w_str), verbose=False)
|
| 106 |
+
flops = 2 * macs
|
| 107 |
+
macs, flops, params = clever_format([macs, flops, params], "%.3f")
|
| 108 |
+
return macs, flops, params
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def benchmark_forward(model, rgb, w_art, w_str, *, device, amp, num_runs=10, warmup=3):
|
| 112 |
+
model.eval()
|
| 113 |
+
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
|
| 114 |
+
|
| 115 |
+
for _ in range(max(0, warmup)):
|
| 116 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 117 |
+
_ = model(rgb, w_art, w_str)
|
| 118 |
+
if device_type == "cuda":
|
| 119 |
+
torch.cuda.synchronize()
|
| 120 |
+
|
| 121 |
+
start = time.perf_counter()
|
| 122 |
+
for _ in range(int(num_runs)):
|
| 123 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 124 |
+
_ = model(rgb, w_art, w_str)
|
| 125 |
+
if device_type == "cuda":
|
| 126 |
+
torch.cuda.synchronize()
|
| 127 |
+
return (time.perf_counter() - start) / max(1, int(num_runs))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def run_end_to_end_once(args, model, train_mos_mean, train_mos_std, device, amp):
|
| 131 |
+
start = time.perf_counter()
|
| 132 |
+
dataset = build_dataset(args, train_mos_mean, train_mos_std)
|
| 133 |
+
rgb, w_art, w_str, _y, _video_id = prepare_single_sample(dataset, device)
|
| 134 |
+
|
| 135 |
+
pred_score = predict_once(
|
| 136 |
+
model,
|
| 137 |
+
rgb,
|
| 138 |
+
w_art,
|
| 139 |
+
w_str,
|
| 140 |
+
device=device,
|
| 141 |
+
amp=amp,
|
| 142 |
+
train_mos_mean=float(train_mos_mean),
|
| 143 |
+
train_mos_std=float(train_mos_std),
|
| 144 |
+
)
|
| 145 |
+
if str(device).startswith("cuda"):
|
| 146 |
+
torch.cuda.synchronize()
|
| 147 |
+
elapsed = time.perf_counter() - start
|
| 148 |
+
return elapsed, pred_score, rgb, w_art, w_str
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def parse_args():
|
| 152 |
+
ap = argparse.ArgumentParser(description="Demo-style single-video test for QD_MODEL")
|
| 153 |
+
# for complexity time test:
|
| 154 |
+
ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/lsvq/qd_model.best.pt")
|
| 155 |
+
ap.add_argument("--db_path", type=str, default="/home/xinyi/Project/FD-VQA/test_videos/")
|
| 156 |
+
ap.add_argument("--video_id", type=str, default="SDR_Animal_5ngj")
|
| 157 |
+
# for resolution compelxity test:
|
| 158 |
+
# ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/kvq/qd_model.best.pt")
|
| 159 |
+
# ap.add_argument("--db_path", type=str, default="/home/xinyi/Project/FD-VQA/test_videos/complexity_test/complexity_resolution/")
|
| 160 |
+
# ap.add_argument("--video_id", type=str, default="SDR_Animal_5ngj_540p")
|
| 161 |
+
|
| 162 |
+
ap.add_argument("--clip_len", type=int, default=16)
|
| 163 |
+
ap.add_argument("--resize", type=int, default=224)
|
| 164 |
+
ap.add_argument("--win", type=int, default=6)
|
| 165 |
+
ap.add_argument("--win_step", type=int, default=1)
|
| 166 |
+
|
| 167 |
+
ap.add_argument("--device", type=str, default="cuda")
|
| 168 |
+
ap.add_argument("--no_amp", action="store_true")
|
| 169 |
+
|
| 170 |
+
ap.add_argument("--dummy_mos", type=float, default=3.0, help="Only used to compute VQADataset, does not affect prediction")
|
| 171 |
+
ap.add_argument("--num_runs", type=int, default=10, help="Average N runs")
|
| 172 |
+
ap.add_argument("--warmup_runs", type=int, default=3)
|
| 173 |
+
ap.add_argument("--skip_profile", action="store_true")
|
| 174 |
+
return ap.parse_args()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main():
|
| 178 |
+
args = parse_args()
|
| 179 |
+
device = torch.device(args.device)
|
| 180 |
+
amp = not bool(args.no_amp)
|
| 181 |
+
|
| 182 |
+
print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
|
| 183 |
+
|
| 184 |
+
display_path = str(Path(args.db_path) / args.video_id)
|
| 185 |
+
info = pd.DataFrame([
|
| 186 |
+
{
|
| 187 |
+
"vid": args.video_id,
|
| 188 |
+
"test_video_path": display_path,
|
| 189 |
+
}
|
| 190 |
+
])
|
| 191 |
+
print(info)
|
| 192 |
+
|
| 193 |
+
dataset_preview = build_dataset(args, mos_mean=args.dummy_mos, mos_std=1.0)
|
| 194 |
+
print(f"Dataset loaded. Total videos: {len(dataset_preview)}, Total batches: 1")
|
| 195 |
+
print(f"Loading model from: {args.ckpt_path}")
|
| 196 |
+
|
| 197 |
+
ckpt_info = load_checkpoint(Path(args.ckpt_path), device)
|
| 198 |
+
train_mos_mean = ckpt_info["train_mos_mean"]
|
| 199 |
+
train_mos_std = ckpt_info["train_mos_std"]
|
| 200 |
+
if train_mos_mean is None or train_mos_std is None:
|
| 201 |
+
raise ValueError("Checkpoint does not contain mos_mean / mos_std. Please use a full checkpoint.")
|
| 202 |
+
if float(train_mos_std) <= 1e-8:
|
| 203 |
+
raise ValueError("train_mos_std must be > 0")
|
| 204 |
+
|
| 205 |
+
model = QD_MODEL(clip_model="openai/clip-vit-base-patch16").to(device)
|
| 206 |
+
model.load_state_dict(ckpt_info["state_dict"], strict=True)
|
| 207 |
+
model.eval()
|
| 208 |
+
|
| 209 |
+
run_times = []
|
| 210 |
+
pred_score = None
|
| 211 |
+
rgb = w_art = w_str = None
|
| 212 |
+
|
| 213 |
+
for i in range(args.num_runs):
|
| 214 |
+
for _ in tqdm(range(1), desc="Processing Videos"):
|
| 215 |
+
elapsed, pred_score, rgb, w_art, w_str = run_end_to_end_once(
|
| 216 |
+
args, model, train_mos_mean, train_mos_std, device, amp
|
| 217 |
+
)
|
| 218 |
+
run_times.append(elapsed)
|
| 219 |
+
print(f"Run {i + 1} - Time taken: {elapsed:.4f} seconds")
|
| 220 |
+
|
| 221 |
+
avg_total_time = sum(run_times) / max(1, len(run_times))
|
| 222 |
+
avg_forward_time = benchmark_forward(
|
| 223 |
+
model,
|
| 224 |
+
rgb,
|
| 225 |
+
w_art,
|
| 226 |
+
w_str,
|
| 227 |
+
device=device,
|
| 228 |
+
amp=amp,
|
| 229 |
+
num_runs=args.num_runs,
|
| 230 |
+
warmup=args.warmup_runs,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
total_params, trainable_params = count_parameters(model)
|
| 234 |
+
macs = flops = params = None
|
| 235 |
+
if not args.skip_profile:
|
| 236 |
+
try:
|
| 237 |
+
macs, flops, params = profile_with_thop(model, rgb, w_art, w_str)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"[WARN] THOP profiling failed: {e}")
|
| 240 |
+
|
| 241 |
+
print(f"Average running time over {args.num_runs} runs: {avg_total_time:.4f} seconds")
|
| 242 |
+
print(f"Predicted Quality Score: {pred_score:.4f}")
|
| 243 |
+
print("\n========== PROFILE SUMMARY ==========")
|
| 244 |
+
print(f"video_id : {args.video_id}")
|
| 245 |
+
print(f"rgb shape : {tuple(rgb.shape)}")
|
| 246 |
+
print(f"w_art shape : {tuple(w_art.shape)}")
|
| 247 |
+
print(f"w_str shape : {tuple(w_str.shape)}")
|
| 248 |
+
print(f"train mos mean/std : {float(train_mos_mean):.6f} / {float(train_mos_std):.6f}")
|
| 249 |
+
print(f"predicted score : {pred_score:.6f}")
|
| 250 |
+
print(f"params total : {total_params:,} ({total_params / 1e6:.3f} M)")
|
| 251 |
+
print(f"params trainable : {trainable_params:,} ({trainable_params / 1e6:.3f} M)")
|
| 252 |
+
if params is not None:
|
| 253 |
+
print(f"Params (THOP) : {params} M")
|
| 254 |
+
if macs is not None:
|
| 255 |
+
print(f"MACs (THOP) : {macs} G")
|
| 256 |
+
if flops is not None:
|
| 257 |
+
print(f"FLOPs (~2*MACs) : {flops} G")
|
| 258 |
+
print(f"avg forward time : {avg_forward_time:.6f} s (runs={args.num_runs})")
|
| 259 |
+
print(f"avg end-to-end time : {avg_total_time:.6f} s (sample prep + forward)")
|
| 260 |
+
print("=====================================")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
main()
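The script above reports times averaged over `num_runs` runs and passes a `warmup` count to `benchmark_forward`. As a minimal, dependency-free sketch of that warm-up-then-average pattern (the helper name `time_callable` and its arguments are hypothetical; the repository's own `benchmark_forward` is not reproduced here and may differ):

```python
import time


def time_callable(fn, num_runs=5, warmup=2):
    """Return the mean wall-clock seconds of fn() over num_runs timed calls.

    Hypothetical helper for illustration: warm-up calls run but are not timed.
    """
    for _ in range(max(0, warmup)):
        fn()
    run_times = []
    for _ in range(max(0, num_runs)):
        start = time.perf_counter()
        fn()
        run_times.append(time.perf_counter() - start)
    # Same guard as the script: avoid dividing by zero when nothing was timed.
    return sum(run_times) / max(1, len(run_times))


if __name__ == "__main__":
    avg = time_callable(lambda: sum(range(10_000)), num_runs=3, warmup=1)
    print(f"avg time: {avg:.6f} s")
```

When the timed callable launches GPU work, the device usually has to be synchronized before each clock read (for example with `torch.cuda.synchronize()`), because CUDA kernels run asynchronously. The sketch leaves that out.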
|
src/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
# src/model/__init__.py
|
src/model/clip_dense_encoder.py
ADDED
|
@@ -0,0 +1,61 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import CLIPModel
|
| 4 |
+
from transformers.utils import logging
|
| 5 |
+
logging.set_verbosity_error()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CLIPDenseEncoder(nn.Module):
|
| 9 |
+
def __init__(self, model_name="openai/clip-vit-base-patch16"):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.clip = CLIPModel.from_pretrained(str(model_name), use_safetensors=True)
|
| 12 |
+
|
| 13 |
+
self.vision = self.clip.vision_model
|
| 14 |
+
vcfg = self.clip.config.vision_config
|
| 15 |
+
self.hidden_size = int(vcfg.hidden_size)
|
| 16 |
+
self.patch_size = int(vcfg.patch_size)
|
| 17 |
+
self.image_size = int(vcfg.image_size)
|
| 18 |
+
|
| 19 |
+
# OpenAI CLIP normalization constants
|
| 20 |
+
self.register_buffer("_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
|
| 21 |
+
self.register_buffer("_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
|
| 22 |
+
|
| 23 |
+
def forward(self, x01):
|
| 24 |
+
x = (x01 - self._mean.to(dtype=x01.dtype)) / self._std.to(dtype=x01.dtype)
|
| 25 |
+
out = self.vision(pixel_values=x)
|
| 26 |
+
tokens = out.last_hidden_state[:, 1:, :] # [B, 1+N, C] drop CLS -> [B, N, C]
|
| 27 |
+
b, n, c = tokens.shape
|
| 28 |
+
side = int(n**0.5)
|
| 29 |
+
if side * side != n:
|
| 30 |
+
raise RuntimeError(f"CLIP patch tokens N={n} not a square; cannot reshape to 2D grid.")
|
| 31 |
+
|
| 32 |
+
fmap = tokens.transpose(1, 2).contiguous().view(b, c, side, side)
|
| 33 |
+
return fmap
|
| 34 |
+
|
| 35 |
+
def freeze_all(self):
|
| 36 |
+
for p in self.clip.parameters():
|
| 37 |
+
p.requires_grad = False
|
| 38 |
+
|
| 39 |
+
def unfreeze_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
|
| 40 |
+
self.freeze_all()
|
| 41 |
+
layers = self.vision.encoder.layers
|
| 42 |
+
n_total = len(layers)
|
| 43 |
+
|
| 44 |
+
# unfreeze the whole vision tower
|
| 45 |
+
if n_blocks < 0 or n_blocks >= n_total:
|
| 46 |
+
for p in self.vision.parameters():
|
| 47 |
+
p.requires_grad = True
|
| 48 |
+
# unfreeze only the last n_blocks encoder layers
|
| 49 |
+
else:
|
| 50 |
+
k = max(0, min(int(n_blocks), n_total))
|
| 51 |
+
for i in range(n_total - k, n_total):
|
| 52 |
+
for p in layers[i].parameters():
|
| 53 |
+
p.requires_grad = True
|
| 54 |
+
|
| 55 |
+
if also_unfreeze_ln:
|
| 56 |
+
if hasattr(self.vision, "pre_layrnorm"):
|
| 57 |
+
for p in self.vision.pre_layrnorm.parameters():
|
| 58 |
+
p.requires_grad = True
|
| 59 |
+
if hasattr(self.vision, "post_layernorm"):
|
| 60 |
+
for p in self.vision.post_layernorm.parameters():
|
| 61 |
+
p.requires_grad = True
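The encoder above drops the CLS token, reshapes the remaining N patch tokens into a square grid, and selects which encoder layers stay trainable. The index arithmetic behind both steps can be sketched without any deep-learning dependency. The function names `patch_grid_side` and `layers_to_unfreeze` are hypothetical, and `math.isqrt` is swapped in for the source's `int(n**0.5)`:

```python
import math


def patch_grid_side(num_patch_tokens):
    """Side length of the square patch grid, or raise if N is not a perfect square."""
    side = math.isqrt(num_patch_tokens)  # exact integer sqrt (the source uses int(n**0.5))
    if side * side != num_patch_tokens:
        raise ValueError(f"N={num_patch_tokens} is not a perfect square")
    return side


def layers_to_unfreeze(n_total, n_blocks):
    """Indices of encoder layers left trainable, mirroring the branch logic above."""
    if n_blocks < 0 or n_blocks >= n_total:
        return list(range(n_total))  # unfreeze every layer
    k = max(0, min(int(n_blocks), n_total))
    return list(range(n_total - k, n_total))  # only the last k layers


if __name__ == "__main__":
    # A 224x224 input with 16x16 patches yields (224 // 16) ** 2 = 196 patch tokens.
    print(patch_grid_side((224 // 16) ** 2))
    # Example with an assumed 12-layer encoder.
    print(layers_to_unfreeze(12, 2))
```

In the source, the "unfreeze everything" branch enables gradients for all vision-model parameters, including embeddings, not only the encoder layers. The sketch tracks layer indices only.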
|
src/model/qd_model.py
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
# ----------------------------
|
| 2 |
+
# Model: CLIP encoder + w_art/w_str weighted pooling -> per-frame sequences -> temporal mean pooling -> gated head fusion
|
| 3 |
+
# ----------------------------
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .clip_dense_encoder import CLIPDenseEncoder
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def weighted_pool_2d(fmap, wmap, eps=1e-6):
|
| 11 |
+
"""
|
| 12 |
+
fmap: [B, C, H, W]
|
| 13 |
+
wmap: [B, 1, H, W]
|
| 14 |
+
-> [B, C]
|
| 15 |
+
"""
|
| 16 |
+
w = wmap.clamp(0.0, 1.0)
|
| 17 |
+
w = w / (w.sum(dim=(2, 3), keepdim=True) + eps)
|
| 18 |
+
return (fmap * w).sum(dim=(2, 3))
|
| 19 |
+
|
| 20 |
+
def default_gate_stats(w_art, w_str, fmap_bt, eps=1e-6):
|
| 21 |
+
"""
|
| 22 |
+
w_art: [B, 1, T, H, W]
|
| 23 |
+
w_str: [B, 1, T, H, W]
|
| 24 |
+
fmap_bt: [B, T, C, H, W]
|
| 25 |
+
-> [B, 3]
|
| 26 |
+
"""
|
| 27 |
+
mu_art = w_art.mean(dim=(1, 2, 3, 4))
|
| 28 |
+
mu_str = w_str.mean(dim=(1, 2, 3, 4))
|
| 29 |
+
mu_raw = fmap_bt.abs().mean(dim=(1, 2, 3, 4))
|
| 30 |
+
|
| 31 |
+
stats = torch.stack([mu_art, mu_str, mu_raw], dim=1)
|
| 32 |
+
return stats
|
| 33 |
+
|
| 34 |
+
def two_gate_stats(w_art, w_str):
|
| 35 |
+
# [B, 2]
|
| 36 |
+
mu_art = w_art.mean(dim=(1, 2, 3, 4))
|
| 37 |
+
mu_str = w_str.mean(dim=(1, 2, 3, 4))
|
| 38 |
+
return torch.stack([mu_art, mu_str], dim=1)
|
| 39 |
+
|
| 40 |
+
class QD_MODEL(nn.Module):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
*,
|
| 44 |
+
clip_model="openai/clip-vit-base-patch16",
|
| 45 |
+
head_hidden=384,
|
| 46 |
+
gate_hidden=32,
|
| 47 |
+
head_dropout=0.2,
|
| 48 |
+
gate_dropout=0.1,
|
| 49 |
+
ablation_mode="full", # "full" | "art" | "str" | "raw" | "art+str"
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.ablation_mode = ablation_mode
|
| 53 |
+
|
| 54 |
+
self.encoder = CLIPDenseEncoder(model_name=str(clip_model))
|
| 55 |
+
c = int(self.encoder.hidden_size)
|
| 56 |
+
|
| 57 |
+
# 3 heads
|
| 58 |
+
self.head_art = nn.Sequential(
|
| 59 |
+
nn.Linear(c, head_hidden),
|
| 60 |
+
nn.ReLU(),
|
| 61 |
+
nn.Dropout(head_dropout),
|
| 62 |
+
nn.Linear(head_hidden, head_hidden // 2),
|
| 63 |
+
nn.ReLU(),
|
| 64 |
+
nn.Dropout(head_dropout),
|
| 65 |
+
nn.Linear(head_hidden // 2, 1),
|
| 66 |
+
)
|
| 67 |
+
self.head_str = nn.Sequential(
|
| 68 |
+
nn.Linear(c, head_hidden),
|
| 69 |
+
nn.ReLU(),
|
| 70 |
+
nn.Dropout(head_dropout),
|
| 71 |
+
nn.Linear(head_hidden, head_hidden // 2),
|
| 72 |
+
nn.ReLU(),
|
| 73 |
+
nn.Dropout(head_dropout),
|
| 74 |
+
nn.Linear(head_hidden // 2, 1),
|
| 75 |
+
)
|
| 76 |
+
self.head_fmap = nn.Sequential(
|
| 77 |
+
nn.Linear(c, head_hidden),
|
| 78 |
+
nn.ReLU(),
|
| 79 |
+
nn.Dropout(head_dropout),
|
| 80 |
+
nn.Linear(head_hidden, head_hidden // 2),
|
| 81 |
+
nn.ReLU(),
|
| 82 |
+
nn.Dropout(head_dropout),
|
| 83 |
+
nn.Linear(head_hidden // 2, 1),
|
| 84 |
+
)
|
| 85 |
+
# full model: 3-way gate
|
| 86 |
+
self.gate = nn.Sequential(
|
| 87 |
+
nn.Linear(3, gate_hidden),
|
| 88 |
+
nn.ReLU(),
|
| 89 |
+
nn.Dropout(gate_dropout),
|
| 90 |
+
nn.Linear(gate_hidden, 3),
|
| 91 |
+
)
|
| 92 |
+
# ablation: two gates only
|
| 93 |
+
self.gate_two = nn.Sequential(
|
| 94 |
+
nn.Linear(2, gate_hidden),
|
| 95 |
+
nn.ReLU(),
|
| 96 |
+
nn.Dropout(gate_dropout),
|
| 97 |
+
nn.Linear(gate_hidden, 2),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def freeze_clip_all(self):
|
| 101 |
+
self.encoder.freeze_all()
|
| 102 |
+
|
| 103 |
+
def unfreeze_clip_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
|
| 104 |
+
self.encoder.unfreeze_last_blocks(
|
| 105 |
+
n_blocks=n_blocks,
|
| 106 |
+
also_unfreeze_ln=also_unfreeze_ln,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def temporal_pool(self, x):
|
| 110 |
+
"""
|
| 111 |
+
x: [B, T, C]
|
| 112 |
+
return: [B, C]
|
| 113 |
+
"""
|
| 114 |
+
return x.mean(dim=1)
|
| 115 |
+
|
| 116 |
+
def forward(self, rgb, w_art, w_str, gate_stats=None):
|
| 117 |
+
"""
|
| 118 |
+
rgb: [B,3,T,H,W] (rgb in [0,1])
|
| 119 |
+
w_art/w_str: [B,1,T,H,W] (0..1)
|
| 120 |
+
gate_stats: optional; [B,3] for "full", [B,2] for "art+str" (computed from the inputs when None)
|
| 121 |
+
"""
|
| 122 |
+
B, _C, T, H, W = rgb.shape
|
| 123 |
+
|
| 124 |
+
x2d = rgb.permute(0, 2, 1, 3, 4).contiguous().view(B * T, 3, H, W)
|
| 125 |
+
fmap2d = self.encoder(x2d)
|
| 126 |
+
_, C, Hp, Wp = fmap2d.shape # (B*T, 768, 14, 14)
|
| 127 |
+
|
| 128 |
+
w_art_bt = w_art.transpose(1, 2) # (B, T, 1, 14, 14)
|
| 129 |
+
w_str_bt = w_str.transpose(1, 2)
|
| 130 |
+
fmap_bt = fmap2d.view(B, T, C, Hp, Wp) # (B, T, 768, 14, 14)
|
| 131 |
+
|
| 132 |
+
z_art = torch.stack([weighted_pool_2d(fmap_bt[:, i], w_art_bt[:, i]) for i in range(T)], dim=1) # (B, T, 768)
|
| 133 |
+
z_str = torch.stack([weighted_pool_2d(fmap_bt[:, i], w_str_bt[:, i]) for i in range(T)], dim=1)
|
| 134 |
+
z_raw = fmap_bt.mean(dim=(-2, -1))
|
| 135 |
+
|
| 136 |
+
h_art = self.temporal_pool(z_art) # [B, C]
|
| 137 |
+
h_str = self.temporal_pool(z_str)
|
| 138 |
+
h_fmap = self.temporal_pool(z_raw)
|
| 139 |
+
|
| 140 |
+
q_art = self.head_art(h_art)
|
| 141 |
+
q_str = self.head_str(h_str)
|
| 142 |
+
q_fmap = self.head_fmap(h_fmap)
|
| 143 |
+
|
| 144 |
+
mode = self.ablation_mode.lower()
|
| 145 |
+
if mode == "art":
|
| 146 |
+
y_hat = q_art
|
| 147 |
+
weights = torch.tensor([1.0, 0.0, 0.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
|
| 148 |
+
elif mode == "str":
|
| 149 |
+
y_hat = q_str
|
| 150 |
+
weights = torch.tensor([0.0, 1.0, 0.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
|
| 151 |
+
elif mode == "raw":
|
| 152 |
+
y_hat = q_fmap
|
| 153 |
+
weights = torch.tensor([0.0, 0.0, 1.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
|
| 154 |
+
elif mode == "art+str":
|
| 155 |
+
if gate_stats is None:
|
| 156 |
+
gate_stats = two_gate_stats(w_art, w_str) # [B, 2]
|
| 157 |
+
g_ar = self.gate_two(gate_stats) # [B, 2]
|
| 158 |
+
a, b_ = torch.softmax(g_ar, dim=1).split(1, dim=1)
|
| 159 |
+
y_hat = a * q_art + b_ * q_str
|
| 160 |
+
zero = torch.zeros_like(a)
|
| 161 |
+
weights = torch.cat([a, b_, zero], dim=1)
|
| 162 |
+
elif mode == "full":
|
| 163 |
+
if gate_stats is None:
|
| 164 |
+
gate_stats = default_gate_stats(w_art, w_str, fmap_bt)
|
| 165 |
+
g = self.gate(gate_stats)
|
| 166 |
+
a, b_, c_ = torch.softmax(g, dim=1).split(1, dim=1)
|
| 167 |
+
y_hat = a * q_art + b_ * q_str + c_ * q_fmap
|
| 168 |
+
weights = torch.cat([a, b_, c_], dim=1)
|
| 169 |
+
else:
|
| 170 |
+
raise ValueError(f"Unknown ablation_mode: {self.ablation_mode}")
|
| 171 |
+
|
| 172 |
+
aux = (
|
| 173 |
+
q_art.squeeze(1),
|
| 174 |
+
q_str.squeeze(1),
|
| 175 |
+
q_fmap.squeeze(1),
|
| 176 |
+
weights,
|
| 177 |
+
)
|
| 178 |
+
return y_hat.squeeze(1), aux
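Two operations carry most of the model above: the weight-map pooling in `weighted_pool_2d` and the softmax-gated fusion of the three head scores. Below is a minimal pure-Python sketch of both for a single channel and a single sample. The names `weighted_pool`, `softmax`, and `fuse_scores` are hypothetical stand-ins, not functions from the repository:

```python
import math


def weighted_pool(values, weights, eps=1e-6):
    """Weighted average of one channel's flattened feature map.

    Weights are clamped to [0, 1] and normalized by their sum plus eps,
    following the same steps as weighted_pool_2d.
    """
    w = [min(max(x, 0.0), 1.0) for x in weights]
    total = sum(w) + eps
    return sum(v * x / total for v, x in zip(values, w))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def fuse_scores(q_art, q_str, q_fmap, gate_logits):
    """Convex combination of the three head scores using softmax gate weights."""
    a, b, c = softmax(gate_logits)
    return a * q_art + b * q_str + c * q_fmap, (a, b, c)


if __name__ == "__main__":
    print(weighted_pool([1.0, 3.0], [1.0, 0.0]))
    print(fuse_scores(0.2, 0.4, 0.9, [0.0, 0.0, 0.0]))
```

Because of the `eps` term, an all-zero weight map pools to a zero vector rather than raising a division error. The softmax keeps the fused score inside the range spanned by the three head outputs.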
|
src/module/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
# src/module/__init__.py
|
src/module/compute_weight_map.py
ADDED
|
@@ -0,0 +1,144 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import csv
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import os
|
| 5 |
+
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
|
| 6 |
+
from decord import VideoReader, cpu
|
| 7 |
+
|
| 8 |
+
from module.frequency_dct import compute_twostream_dct
|
| 9 |
+
from module.read_frame_decord import (
|
| 10 |
+
sample_frames_uniform,
|
| 11 |
+
collect_needed,
|
| 12 |
+
cache_needed_frames,
|
| 13 |
+
read_window_from_cache,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
def process_video(
|
| 17 |
+
vr,
|
| 18 |
+
# resize
|
| 19 |
+
size=224,
|
| 20 |
+
# anchors + window
|
| 21 |
+
num_anchors=16,
|
| 22 |
+
win=6,
|
| 23 |
+
win_step=1,
|
| 24 |
+
# DCT / weights
|
| 25 |
+
block=16,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Returns (frame_all, w_art_all, w_str_all, anchors_kept).
|
| 29 |
+
frame_all: anchor frames (RGB uint8, HxWx3)
|
| 30 |
+
w_art_all/w_str_all: maps (float32, HxW)
|
| 31 |
+
anchors_kept: frame indices used per anchor - list[list[int]]
|
| 32 |
+
"""
|
| 33 |
+
total_frames = len(vr)
|
| 34 |
+
if total_frames <= 1:
|
| 35 |
+
raise RuntimeError(f"Video too short / invalid frame count: {total_frames}")
|
| 36 |
+
|
| 37 |
+
anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
|
| 38 |
+
needed = collect_needed(anchor_idxs, total_frames, win, win_step)
|
| 39 |
+
cache = cache_needed_frames(vr, needed, size)
|
| 40 |
+
|
| 41 |
+
frame_all, w_art_all, w_str_all = [], [], []
|
| 42 |
+
anchors_kept = []
|
| 43 |
+
|
| 44 |
+
for anchor in anchor_idxs:
|
| 45 |
+
out = read_window_from_cache(cache, anchor, total_frames, win, win_step)
|
| 46 |
+
if out is None:
|
| 47 |
+
continue
|
| 48 |
+
|
| 49 |
+
anchor_frame, gray_seq, idxs = out
|
| 50 |
+
|
| 51 |
+
w_art, w_str, _dbg = compute_twostream_dct(
|
| 52 |
+
gray_seq,
|
| 53 |
+
block=block,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
frame_all.append(anchor_frame)
|
| 57 |
+
w_art_all.append(w_art.astype(np.float32, copy=False))
|
| 58 |
+
w_str_all.append(w_str.astype(np.float32, copy=False))
|
| 59 |
+
anchors_kept.append([int(x) for x in idxs])
|
| 60 |
+
|
| 61 |
+
return frame_all, w_art_all, w_str_all, anchors_kept
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def compute_weight_map(frame_all, w_art_all, w_str_all):
|
| 65 |
+
if len(frame_all) == 0:
|
| 66 |
+
raise ValueError("No frames produced.")
|
| 67 |
+
if not (len(frame_all) == len(w_art_all) == len(w_str_all)):
|
| 68 |
+
raise ValueError(
|
| 69 |
+
f"Length mismatch: frames={len(frame_all)}, w_art_all={len(w_art_all)}, w_str_all={len(w_str_all)}"
|
| 70 |
+
)
|
| 71 |
+
frames_np = np.stack(frame_all, axis=0) # (N,H,W,3) uint8
|
| 72 |
+
w_art_np = np.stack(w_art_all, axis=0) # (N,H,W) float32
|
| 73 |
+
w_str_np = np.stack(w_str_all, axis=0) # (N,H,W) float32
|
| 74 |
+
return frames_np, w_art_np, w_str_np
|
| 75 |
+
|
| 76 |
+
def read_vid_mos_csv(csv_path):
|
| 77 |
+
rows = []
|
| 78 |
+
with open(csv_path, "r", encoding="utf-8") as f:
|
| 79 |
+
reader = csv.DictReader(f)
|
| 80 |
+
if not reader.fieldnames:
|
| 81 |
+
raise RuntimeError("CSV has no header")
|
| 82 |
+
for r in reader:
|
| 83 |
+
vid = str(r["vid"]).strip()
|
| 84 |
+
mos = float(r["mos"])
|
| 85 |
+
rows.append((vid, mos))
|
| 86 |
+
return rows
|
| 87 |
+
|
| 88 |
+
def rng_rows(rows, seed=0):
|
| 89 |
+
rng = np.random.default_rng(int(seed))
|
| 90 |
+
idx = np.arange(len(rows))
|
| 91 |
+
rng.shuffle(idx)
|
| 92 |
+
train = [rows[i] for i in idx[:]]
|
| 93 |
+
return train
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
csv_path = "/home/xinyi/Project/FD-VQA/metadata/TEST_metadata.csv"
|
| 97 |
+
db_path = "/home/xinyi/Project/FD-VQA/test_videos/"
|
| 98 |
+
# video_path = "/home/xinyi/Project/FD-VQA/test_videos/NesAirFortressIn4108.37ByTool23.mp4"
|
| 99 |
+
|
| 100 |
+
rows = read_vid_mos_csv(str(csv_path))
|
| 101 |
+
train = rng_rows(rows)
|
| 102 |
+
for i in range(len(train)):
|
| 103 |
+
vid, mos = train[i]
|
| 104 |
+
print(vid, mos)
|
| 105 |
+
# get video path
|
| 106 |
+
base_path = Path(db_path) / vid
|
| 107 |
+
video_path = None
|
| 108 |
+
for ext in ("mp4", "avi", "mkv"):
|
| 109 |
+
p = Path(str(base_path) + f".{ext}")
|
| 110 |
+
if p.exists():
|
| 111 |
+
video_path = str(p)
|
| 112 |
+
break
|
| 113 |
+
if video_path is None:
|
| 114 |
+
raise FileNotFoundError(f"Cannot find {vid} video")
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
# read video
|
| 118 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
| 119 |
+
frame_all, w_art_all, w_str_all, anchors_kept = process_video(
|
| 120 |
+
vr,
|
| 121 |
+
size=224,
|
| 122 |
+
num_anchors=16,
|
| 123 |
+
win=6,
|
| 124 |
+
win_step=1,
|
| 125 |
+
block=16,
|
| 126 |
+
)
|
| 127 |
+
frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
|
| 128 |
+
print("frames_np:", frames_np.shape, frames_np.dtype)
|
| 129 |
+
print("w_art_np:", w_art_np.shape, w_art_np.dtype)
|
| 130 |
+
print("w_str_np:", w_str_np.shape, w_str_np.dtype)
|
| 131 |
+
print("anchors_kept:", len(anchors_kept), "example:", anchors_kept[0])
|
| 132 |
+
except Exception:
|
| 133 |
+
print("\n[DATA ERROR]")
|
| 134 |
+
print("idx:", i)
|
| 135 |
+
print("vid:", vid)
|
| 136 |
+
print("path:", video_path)
|
| 137 |
+
raise
|
| 138 |
+
finally:
|
| 139 |
+
# release decord video reader
|
| 140 |
+
try:
|
| 141 |
+
if vr is not None:
|
| 142 |
+
del vr
|
| 143 |
+
except Exception:
|
| 144 |
+
pass
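The `__main__` block above resolves each video id to a file by trying a fixed list of extensions. That lookup can be factored into a small helper, sketched here with only the standard library. The name `find_video_path` is hypothetical, and it returns `None` where the script raises `FileNotFoundError`:

```python
import tempfile
from pathlib import Path


def find_video_path(db_dir, vid, exts=("mp4", "avi", "mkv")):
    """Return the first existing '<db_dir>/<vid>.<ext>' as a string, or None."""
    base = Path(db_dir) / vid
    for ext in exts:
        # Append the suffix to the full name; Path.with_suffix would replace
        # everything after the last dot in ids such as "clip.v2".
        candidate = Path(f"{base}.{ext}")
        if candidate.exists():
            return str(candidate)
    return None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "clip.v2.mkv").touch()
        print(find_video_path(tmp, "clip.v2"))
        print(find_video_path(tmp, "missing"))
```

Appending the extension as a string, as the source does, keeps ids that contain dots intact.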
|
src/module/frequency_dct.py
ADDED
|
@@ -0,0 +1,351 @@
|
| 1 |
+
import os
|
| 2 |
+
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
|
| 3 |
+
from decord import VideoReader, cpu
|
| 4 |
+
import cv2
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
from module.read_frame_decord import sample_frames_uniform, collect_needed, cache_needed_frames, read_window_from_cache
|
| 13 |
+
|
| 14 |
+
# ----------------------------
|
| 15 |
+
# utils
|
| 16 |
+
# ----------------------------
|
| 17 |
+
def norm01(x, eps=1e-6):
|
| 18 |
+
x = x.astype(np.float32)
|
| 19 |
+
mn, mx = float(x.min()), float(x.max())
|
| 20 |
+
return (x - mn) / (mx - mn + eps)
|
| 21 |
+
|
| 22 |
+
# def upsample_grid(grid, out_hw, interp=cv2.INTER_LINEAR):
|
| 23 |
+
def upsample_grid(grid, out_hw, interp=cv2.INTER_NEAREST):
|
| 24 |
+
H, W = out_hw
|
| 25 |
+
return cv2.resize(grid.astype(np.float32), (W, H), interpolation=interp)
|
| 26 |
+
|
| 27 |
+
def gaussian_blur(x, sigma=1.0):
|
| 28 |
+
if sigma <= 0:
|
| 29 |
+
return x
|
| 30 |
+
ksize = int(6 * sigma + 1)
|
| 31 |
+
if ksize % 2 == 0:
|
| 32 |
+
ksize += 1
|
| 33 |
+
if ksize < 3:
|
| 34 |
+
ksize = 3
|
| 35 |
+
return cv2.GaussianBlur(x, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT_101)
|
| 36 |
+
|
| 37 |
+
# sobel operator
|
| 38 |
+
def gradient_mag(gray):
|
| 39 |
+
gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
|
| 40 |
+
gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
|
| 41 |
+
dst = np.sqrt(gx * gx + gy * gy)
|
| 42 |
+
# plt.imshow(dst, cmap='gray')
|
| 43 |
+
# plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/Sobel_operator_result.jpg', dpi=300)
|
| 44 |
+
return dst
|
| 45 |
+
|
| 46 |
+
# Sobel magnitude -> normalize -> threshold -> dilate edge region
|
| 47 |
+
def edge_sobel(gray, thr=0.20, dilate_px=2):
|
| 48 |
+
g = gradient_mag(gray).astype(np.float32)
|
| 49 |
+
g = norm01(g) # normalize
|
| 50 |
+
edge = (g >= thr).astype(np.uint8)
|
| 51 |
+
if dilate_px > 0:
|
| 52 |
+
k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * dilate_px + 1, 2 * dilate_px + 1))
|
| 53 |
+
edge = cv2.dilate(edge, k, iterations=1)
|
| 54 |
+
# plt.imshow(edge, cmap='gray')
|
| 55 |
+
# plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/edge_result.jpg', dpi=300)
|
| 56 |
+
return edge
|
| 57 |
+
|
| 58 |
+
# per-block fraction of near edge pixels
|
| 59 |
+
def block_edge(edge_mask, block=16):
|
| 60 |
+
h, w = edge_mask.shape
|
| 61 |
+
gh, gw = h // block, w // block
|
| 62 |
+
frac = np.zeros((gh, gw), dtype=np.float32)
|
| 63 |
+
for by in range(gh):
|
| 64 |
+
for bx in range(gw):
|
| 65 |
+
y0, x0 = by * block, bx * block
|
| 66 |
+
frac[by, bx] = float(edge_mask[y0:y0 + block, x0:x0 + block].mean())
|
| 67 |
+
# plt.imshow(frac, cmap='gray')
|
| 68 |
+
# plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/edge_block.jpg', dpi=300)
|
| 69 |
+
return frac
|
| 70 |
+
|
| 71 |
+
# per-block mean
|
| 72 |
+
def block_mean(img, block=16):
|
| 73 |
+
h, w = img.shape
|
| 74 |
+
gh, gw = h // block, w // block
|
| 75 |
+
x = img[:gh*block, :gw*block]
|
| 76 |
+
x = x.reshape(gh, block, gw, block) # (block_row, inblock_row, block_col, inblock_col)
|
| 77 |
+
return x.mean(axis=(1, 3)).astype(np.float32)
|
| 78 |
+
|
| 79 |
+
# discontinuities across block boundaries: |I[:, x] - I[:, x-1]| & |I[y, :] - I[y-1, :]|
|
| 80 |
+
def blockiness_boundary_map(gray, block_size=16, blur_sigma=1.0):
|
| 81 |
+
h, w = gray.shape
|
| 82 |
+
v = np.zeros((h, w), dtype=np.float32)
|
| 83 |
+
hmap = np.zeros((h, w), dtype=np.float32)
|
| 84 |
+
|
| 85 |
+
for x in range(block_size, w, block_size):
|
| 86 |
+
v[:, x] = np.abs(gray[:, x] - gray[:, x - 1])
|
| 87 |
+
for y in range(block_size, h, block_size):
|
| 88 |
+
hmap[y, :] = np.abs(gray[y, :] - gray[y - 1, :])
|
| 89 |
+
b = v + hmap
|
| 90 |
+
if blur_sigma > 0:
|
| 91 |
+
b = gaussian_blur(b, blur_sigma)
|
| 92 |
+
return b
|
| 93 |
+
|
| 94 |
+
# ----------------------------
|
| 95 |
+
# DCT 16x16 energies
|
| 96 |
+
# ----------------------------
|
| 97 |
+
def dct_energy_ratios(gray, block=16, low_k=4, mid_k=8, eps=1e-6):
|
| 98 |
+
"""
|
| 99 |
+
Return per-block energy ratios for frequency bands.
|
| 100 |
+
low: [0:low_k, 0:low_k]
|
| 101 |
+
mid: [0:mid_k, 0:mid_k] - low
|
| 102 |
+
high: remaining
|
| 103 |
+
"""
|
| 104 |
+
H, W = gray.shape
|
| 105 |
+
gh, gw = H // block, W // block
|
| 106 |
+
r_low = np.zeros((gh, gw), dtype=np.float32)
|
| 107 |
+
r_mid = np.zeros((gh, gw), dtype=np.float32)
|
| 108 |
+
r_high = np.zeros((gh, gw), dtype=np.float32)
|
| 109 |
+
|
| 110 |
+
for by in range(gh):
|
| 111 |
+
for bx in range(gw):
|
| 112 |
+
y0, x0 = by * block, bx * block
|
| 113 |
+
patch = gray[y0:y0 + block, x0:x0 + block].astype(np.float32)
|
| 114 |
+
|
| 115 |
+
C = cv2.dct(patch)
|
| 116 |
+
E = C * C
|
| 117 |
+
E_low = E[:low_k, :low_k].sum() # 4 x 4
|
| 118 |
+
E_mid = E[:mid_k, :mid_k].sum() - E_low # 8 x 8 corner minus the 4 x 4 low band
|
| 119 |
+
E_total = E.sum()
|
| 120 |
+
E_high = max(E_total - (E_low + E_mid), 0.0)
|
| 121 |
+
|
| 122 |
+
denom = E_total + eps
|
| 123 |
+
r_low[by, bx] = E_low / denom
|
| 124 |
+
r_mid[by, bx] = E_mid / denom
|
| 125 |
+
r_high[by, bx] = E_high / denom
|
| 126 |
+
return r_low, r_mid, r_high
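The band split described in the docstring of `dct_energy_ratios` can be checked on a single small block without OpenCV. The sketch below assumes an orthonormal 2-D DCT-II written out by hand, where the source calls `cv2.dct` on 16x16 blocks with `low_k=4` and `mid_k=8`. The names `dct2` and `band_energy_ratios` are hypothetical:

```python
import math


def dct2(block):
    """Orthonormal 2-D DCT-II of a square block given as a list of rows."""
    n = len(block)

    def alpha(k):
        return math.sqrt((1.0 if k == 0 else 2.0) / n)

    cos = [[math.cos(math.pi * (2 * x + 1) * u / (2 * n)) for x in range(n)]
           for u in range(n)]
    out = [[0.0] * n for _ in range(n)]
    for u in range(n):
        for v in range(n):
            s = sum(block[y][x] * cos[u][y] * cos[v][x]
                    for y in range(n) for x in range(n))
            out[u][v] = alpha(u) * alpha(v) * s
    return out


def band_energy_ratios(block, low_k, mid_k, eps=1e-6):
    """Fractions of squared-coefficient energy in the low, mid and high bands."""
    energy = [[c * c for c in row] for row in dct2(block)]
    e_low = sum(energy[u][v] for u in range(low_k) for v in range(low_k))
    e_mid = sum(energy[u][v] for u in range(mid_k) for v in range(mid_k)) - e_low
    e_total = sum(map(sum, energy))
    e_high = max(e_total - (e_low + e_mid), 0.0)
    denom = e_total + eps
    return e_low / denom, e_mid / denom, e_high / denom


if __name__ == "__main__":
    flat = [[1.0] * 8 for _ in range(8)]
    checker = [[1.0 if (x + y) % 2 == 0 else -1.0 for x in range(8)] for y in range(8)]
    print(band_energy_ratios(flat, low_k=2, mid_k=4))
    print(band_energy_ratios(checker, low_k=2, mid_k=4))
```

A constant block puts essentially all of its energy in the low band, while a pixel-level checkerboard puts most of it in the high band. That is the behaviour the ringing and blur maps rely on.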
|
| 127 |
+
|
| 128 |
+
# Temporal map via FFT
|
| 129 |
+
def temporal_fft_map(gray_seq, *, block=16, hf_start_bin=2, eps=1e-6):
|
| 130 |
+
# each frame -> (gh, gw) block-mean grid; stacked over the window -> (K, gh, gw)
|
| 131 |
+
grids = [block_mean(g.astype(np.float32), block=block) for g in gray_seq]
|
| 132 |
+
s = np.stack(grids, axis=0)
|
| 133 |
+
|
| 134 |
+
# rFFT along time
|
| 135 |
+
X = np.fft.rfft(s, axis=0) # (F, gh, gw)
|
| 136 |
+
E = (X.real * X.real + X.imag * X.imag).astype(np.float32) # power spectrum
|
| 137 |
+
|
| 138 |
+
# F = K // 2 + 1
|
| 139 |
+
dc = E[0]
|
| 140 |
+
if E.shape[0] <= 1:
|
| 141 |
+
z = np.zeros_like(dc, dtype=np.float32)
|
| 142 |
+
return z, z
|
| 143 |
+
P = E[1:] # drop DC -> (F-1, gh, gw)
|
| 144 |
+
non_dc = P.sum(axis=0)
|
| 145 |
+
|
| 146 |
+
# changes relative to DC
|
| 147 |
+
motion = non_dc / (dc + eps)
|
| 148 |
+
|
| 149 |
+
# flicker: change is in high temporal freqs
|
| 150 |
+
start = max(int(hf_start_bin) - 1, 0) # index in P
|
| 151 |
+
if start >= P.shape[0]:
|
| 152 |
+
flicker = np.zeros_like(non_dc, dtype=np.float32)
|
| 153 |
+
else:
|
| 154 |
+
hi = P[start:].sum(axis=0)
|
| 155 |
+
flicker = hi / (non_dc + eps)
|
| 156 |
+
return motion.astype(np.float32), flicker.astype(np.float32)
|
| 157 |
+
|
| 158 |
+
def fuse_temporal_maps(motion_grid, flicker_grid, *, beta=0.5):
|
| 159 |
+
m = norm01(motion_grid)
|
| 160 |
+
f = np.clip(flicker_grid, 0.0, 1.0)
|
| 161 |
+
# boosts where flicker is high
|
| 162 |
+
w = m * ((1.0 - beta) + beta * f)
|
| 163 |
+
return norm01(w)
|
| 164 |
+
|
| 165 |
+
# ----------------------------
|
| 166 |
+
# DCT -> two stream weights
|
| 167 |
+
# ----------------------------
|
| 168 |
+
def compute_twostream_dct(
|
| 169 |
+
gray_seq,
|
| 170 |
+
*,
|
| 171 |
+
block=16,
|
| 172 |
+
):
|
| 173 |
+
K = len(gray_seq)
|
| 174 |
+
gray_anchor = gray_seq[0]
|
| 175 |
+
H, W = gray_anchor.shape
|
| 176 |
+
|
| 177 |
+
r_low_stack, r_mid_stack, r_high_stack = [], [], []
|
| 178 |
+
for g in gray_seq:
|
| 179 |
+
r_low, r_mid, r_high = dct_energy_ratios(g, block=block)
|
| 180 |
+
r_low_stack.append(r_low) # (gh, gw)
|
| 181 |
+
r_mid_stack.append(r_mid)
|
| 182 |
+
r_high_stack.append(r_high)
|
| 183 |
+
r_low_stack = np.stack(r_low_stack, axis=0) # (K, gh, gw)
|
| 184 |
+
r_mid_stack = np.stack(r_mid_stack, axis=0)
|
| 185 |
+
r_high_stack = np.stack(r_high_stack, axis=0)
|
| 186 |
+
|
| 187 |
+
# frequency band (anchor frame)
|
| 188 |
+
anchor_low_grid = r_low_stack[0] # (gh, gw)
|
| 189 |
+
anchor_mid_grid = r_mid_stack[0]
|
| 190 |
+
anchor_high_grid = r_high_stack[0]
|
| 191 |
+
|
| 192 |
+
# Ringing map (anchor)
|
| 193 |
+
edge_mask = edge_sobel(gray_anchor) # around edges
|
| 194 |
+
edge_frac = block_edge(edge_mask, block=block)
|
| 195 |
+
mh_band = r_mid_stack[0] + r_high_stack[0] # mid/high frequency energy
|
| 196 |
+
ring_score = np.maximum(mh_band, 0.0) # score = mid_high - 0 * r_low # alpha = 0
|
| 197 |
+
edge_min_frac = 0.05
|
| 198 |
+
ringing_grid = np.where(edge_frac >= edge_min_frac, edge_frac * ring_score, 0.0).astype(np.float32)
|
| 199 |
+
s = np.percentile(ringing_grid, 99) + 1e-6
|
| 200 |
+
ringing_grid01 = np.clip(ringing_grid / s, 0.0, 1.0)
|
| 201 |
+
|
| 202 |
+
# Blur map (anchor): like low-pass filtering
|
| 203 |
+
hf = 0.5 * r_mid_stack[0] + 1.0 * r_high_stack[0]
|
| 204 |
+
blur_raw = np.clip(1.0 - hf, 0.0, 1.0)
|
| 205 |
+
sobel_g = gradient_mag(gray_anchor).astype(np.float32)
|
| 206 |
+
sobel_g_grid = block_mean(sobel_g, block=block)
|
| 207 |
+
sobel_g_grid = norm01(sobel_g_grid) # soft structure weight
|
| 208 |
+
blur_grid = np.clip(blur_raw * sobel_g_grid, 0.0, 1.0)
|
| 209 |
+
|
| 210 |
+
# Blockiness map (anchor): boundary discontinuities
|
| 211 |
+
boundary_pix = blockiness_boundary_map(gray_anchor, block_size=block)
|
| 212 |
+
blockiness_grid = norm01(block_mean(boundary_pix, block=block))
|
| 213 |
+
|
| 214 |
+
# Temporal (window): FFT along time
|
| 215 |
+
if K >= 4:
|
| 216 |
+
motion_grid, flick_grid = temporal_fft_map(gray_seq, block=block, hf_start_bin=2)
|
| 217 |
+
temporal_grid = fuse_temporal_maps(motion_grid, flick_grid, beta=0.5)
|
| 218 |
+
else: # 1 <= K < 4: too few frames for the FFT path (K == 1 or 3 previously left temporal_grid undefined)
|
| 219 |
+
E_stack = norm01(r_mid_stack + r_high_stack)
|
| 220 |
+
temporal_grid = norm01(np.abs(E_stack[-1] - E_stack[0]))
|
| 221 |
+
|
| 222 |
+
# -------------Combine---------------
|
| 223 |
+
w_art = norm01(1.0 * ringing_grid01 + 1.0 * blur_grid + 1.0 * blockiness_grid + 1.0 * temporal_grid)
|
| 224 |
+
w_str = 1.0 - w_art
|
| 225 |
+
|
| 226 |
+
debug = {
|
| 227 |
+
# frequency band (anchor)
|
| 228 |
+
"dct_low_grid": anchor_low_grid,
|
| 229 |
+
"dct_mid_grid": anchor_mid_grid,
|
| 230 |
+
"dct_high_grid": anchor_high_grid,
|
| 231 |
+
# ringing (anchor)
|
| 232 |
+
"ringing_grid": ringing_grid01,
|
| 233 |
+
"edge_px": edge_mask,
|
| 234 |
+
# blur (anchor)
|
| 235 |
+
"blur_grid": blur_grid,
|
| 236 |
+
# blockiness (anchor)
|
| 237 |
+
"blockiness_grid": blockiness_grid,
|
| 238 |
+
# temporal (window)
|
| 239 |
+
"temporal_grid": temporal_grid,
|
| 240 |
+
}
|
| 241 |
+
return w_art, w_str, debug
|
| 242 |
+
|
| 243 |
+
# ----------------------------
|
| 244 |
+
# visualization panel
|
| 245 |
+
# ----------------------------
|
| 246 |
+
def save_panel(out_png, frame_rgb, w_art, w_str, debug):
|
| 247 |
+
fig = plt.figure(figsize=(16, 9), dpi=160)
|
| 248 |
+
def add(ax_i, title, img, cmap=None, vmin=0, vmax=1):
|
| 249 |
+
ax = fig.add_subplot(3, 4, ax_i)
|
| 250 |
+
ax.set_title(title)
|
| 251 |
+
if cmap is None:
|
| 252 |
+
ax.imshow(img)
|
| 253 |
+
else:
|
| 254 |
+
ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
|
| 255 |
+
ax.axis("off")
|
| 256 |
+
|
| 257 |
+
add(1, "Frame_t (anchor)", frame_rgb)
|
| 258 |
+
add(2, "DCT LOW (grid)", norm01(debug["dct_low_grid"]), cmap="viridis")
|
| 259 |
+
add(3, "DCT MID (grid)", norm01(debug["dct_mid_grid"]), cmap="viridis")
|
| 260 |
+
add(4, "DCT HIGH (grid)", norm01(debug["dct_high_grid"]), cmap="viridis")
|
| 261 |
+
# ringing
|
| 262 |
+
add(5, "EDGE (mask)", debug["edge_px"], cmap="viridis")
|
| 263 |
+
add(6, "RINGING (mid/high, grid)", debug["ringing_grid"], cmap="viridis")
|
| 264 |
+
# blur
|
| 265 |
+
add(7, "BLUR (lowpass, grid)", debug["blur_grid"], cmap="viridis")
|
| 266 |
+
# blockiness
|
| 267 |
+
add(8, "BLOCKINESS (boundary, grid)", debug["blockiness_grid"], cmap="viridis")
|
| 268 |
+
# temporal
|
| 269 |
+
add(9, "TEMPORAL (grid)", debug["temporal_grid"], cmap="viridis")
|
| 270 |
+
# all map
|
| 271 |
+
add(10, "W_art", w_art, cmap="viridis")
|
| 272 |
+
add(11, "W_str", w_str, cmap="viridis")
|
| 273 |
+
os.makedirs(os.path.dirname(out_png), exist_ok=True)
|
| 274 |
+
fig.tight_layout()
|
| 275 |
+
fig.savefig(out_png)
|
| 276 |
+
plt.close(fig)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# ----------------------------
|
| 280 |
+
# Main
|
| 281 |
+
# ----------------------------
|
| 282 |
+
def main():
|
| 283 |
+
parser = argparse.ArgumentParser()
|
| 284 |
+
|
| 285 |
+
parser.add_argument("--video", default="/home/xinyi/Project/FD-VQA/test_videos/SDR_Animal_5ngj.mp4")
|
| 286 |
+
parser.add_argument("--out_dir", default="/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only")
|
| 287 |
+
parser.add_argument("--size", type=int, default=224)
|
| 288 |
+
|
| 289 |
+
# fixed-T anchors over whole video (duration-uniform)
|
| 290 |
+
parser.add_argument("--num_anchors", type=int, default=16)
|
| 291 |
+
parser.add_argument("--win", type=int, default=6)
|
| 292 |
+
parser.add_argument("--win_step", type=int, default=1)
|
| 293 |
+
parser.add_argument("--block", type=int, default=16)
|
| 294 |
+
|
| 295 |
+
parser.add_argument("--no_panel", action="store_true")
|
| 296 |
+
args = parser.parse_args()
|
| 297 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
| 298 |
+
|
| 299 |
+
vr = VideoReader(args.video, ctx=cpu(0))
|
| 300 |
+
total_frames = len(vr)
|
| 301 |
+
if total_frames <= 1:
|
| 302 |
+
raise RuntimeError(f"Video too short / frame count unavailable: total_frames={total_frames}")
|
| 303 |
+
print("total_frames:", total_frames)
|
| 304 |
+
|
| 305 |
+
size = args.size
|
| 306 |
+
win = args.win
|
| 307 |
+
win_step = args.win_step
|
| 308 |
+
num_anchors = args.num_anchors
|
| 309 |
+
|
| 310 |
+
anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
|
| 311 |
+
needed = collect_needed(anchor_idxs, total_frames, win, win_step)
|
| 312 |
+
print("anchor_idxs:", anchor_idxs)
|
| 313 |
+
cache = cache_needed_frames(vr, needed, size)
|
| 314 |
+
print("cached:", len(cache), "needed:", len(needed))
|
| 315 |
+
|
| 316 |
+
frame_all, w_art_all, w_str_all = [], [], []
|
| 317 |
+
anchors_kept = []
|
| 318 |
+
image_idx = 0
|
| 319 |
+
|
| 320 |
+
for anchor in tqdm(anchor_idxs, desc="Processing anchors (DCT)"):
|
| 321 |
+
out = read_window_from_cache(cache, anchor, total_frames, win, win_step)
|
| 322 |
+
if out is None:
|
| 323 |
+
continue
|
| 324 |
+
anchor_frame, gray_seq, idxs = out
|
| 325 |
+
|
| 326 |
+
w_art, w_str, dbg = compute_twostream_dct(
|
| 327 |
+
gray_seq,
|
| 328 |
+
block=args.block,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
frame_all.append(anchor_frame)
|
| 332 |
+
w_art_all.append(w_art)
|
| 333 |
+
w_str_all.append(w_str)
|
| 334 |
+
anchors_kept.append(idxs)
|
| 335 |
+
|
| 336 |
+
image_idx += 1
|
| 337 |
+
if not args.no_panel:
|
| 338 |
+
save_panel(
|
| 339 |
+
os.path.join(args.out_dir, f"anchor_{anchor:03d}_{image_idx:02d}.png"),
|
| 340 |
+
anchor_frame,
|
| 341 |
+
w_art,
|
| 342 |
+
w_str,
|
| 343 |
+
dbg,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
print(f"Done. Outputs saved to: {args.out_dir}")
|
| 347 |
+
print(anchors_kept)
|
| 348 |
+
print(f"total_frames={total_frames}, num_anchors_target={num_anchors}, anchors_produced={len(w_str_all)}")
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
|
| 351 |
+
main()
|
src/module/read_frame_decord.py
ADDED
|
@@ -0,0 +1,138 @@
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
|
| 5 |
+
from decord import VideoReader, cpu
|
| 6 |
+
|
| 7 |
+
def sample_frames_uniform(total_frames, num_anchors, win=8, win_step=1):
|
| 8 |
+
if total_frames <= 0:
|
| 9 |
+
return [0] * num_anchors
|
| 10 |
+
|
| 11 |
+
min_win = (win - 1) * win_step + 1
|
| 12 |
+
if total_frames < num_anchors + min_win:
|
| 13 |
+
if total_frames >= num_anchors:
|
| 14 |
+
anchor_idxs = np.linspace(0, total_frames - 1, num_anchors).round().astype(int)
|
| 15 |
+
return anchor_idxs.tolist()
|
| 16 |
+
else:
|
| 17 |
+
anchor_idxs = list(range(total_frames))
|
| 18 |
+
anchor_idxs.extend([total_frames - 1] * (num_anchors - total_frames))
|
| 19 |
+
return anchor_idxs
|
| 20 |
+
|
| 21 |
+
max_start = total_frames - min_win
|
| 22 |
+
anchor_idxs = np.linspace(0, max_start, num_anchors).round().astype(int)
|
| 23 |
+
return anchor_idxs.tolist()
|
| 24 |
+
|
| 25 |
+
# window idx for anchor
|
| 26 |
+
def window_indices(anchor, total_frames, win, win_step):
|
| 27 |
+
last = total_frames - 1
|
| 28 |
+
idxs = [anchor + k * win_step for k in range(win)]
|
| 29 |
+
return [i if i < total_frames else last for i in idxs]
|
| 30 |
+
|
| 31 |
+
# collect idx for all windows
|
| 32 |
+
def collect_needed(anchor_idxs, total_frames, win, win_step):
|
| 33 |
+
needed = set()
|
| 34 |
+
for a in anchor_idxs:
|
| 35 |
+
for idx in window_indices(a, total_frames, win, win_step):
|
| 36 |
+
needed.add(idx)
|
| 37 |
+
return needed
|
| 38 |
+
|
| 39 |
+
# cache the needed frames by decoding the video sequentially
|
| 40 |
+
def cache_needed_frames(vr, needed, size):
|
| 41 |
+
if not needed:
|
| 42 |
+
return {}
|
| 43 |
+
max_idx = max(needed)
|
| 44 |
+
cache = {}
|
| 45 |
+
|
| 46 |
+
last_ok = -1
|
| 47 |
+
for i in range(max_idx + 1):
|
| 48 |
+
try:
|
| 49 |
+
# read next frame
|
| 50 |
+
frame_rgb = vr.next().asnumpy() # decord frame RGB (H,W,3)
|
| 51 |
+
except StopIteration:
|
| 52 |
+
# video ended early, like cap.read() failed
|
| 53 |
+
print(f"[DEBUG] vr.next() reached end of video at i={i}, last_ok={last_ok}, max_idx={max_idx}")
|
| 54 |
+
print(f"[DEBUG] max_cached={max(cache.keys()) if cache else None}, last_ok={last_ok}")
|
| 55 |
+
break
|
| 56 |
+
except Exception as e:
|
| 57 |
+
# decode error
|
| 58 |
+
print(f"[DEBUG] vr.next() failed to decode at i={i}, last_ok={last_ok}, max_idx={max_idx} | {e}")
|
| 59 |
+
print(f"[DEBUG] max_cached={max(cache.keys()) if cache else None}, last_ok={last_ok}")
|
| 60 |
+
break
|
| 61 |
+
|
| 62 |
+
last_ok = i
|
| 63 |
+
if i in needed:
|
| 64 |
+
# frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) # keep same as OpenCV: BGR
|
| 65 |
+
frame = cv2.resize(frame_rgb, (size, size), interpolation=cv2.INTER_AREA)
|
| 66 |
+
gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32)
|
| 67 |
+
cache[i] = (frame, gray)
|
| 68 |
+
return cache
|
| 69 |
+
|
| 70 |
+
# each anchor: read window from cache frames
|
| 71 |
+
def read_window_from_cache(cache, anchor, total_frames, win, win_step):
|
| 72 |
+
idxs = window_indices(anchor, total_frames, win, win_step)
|
| 73 |
+
# print(f"window idxs: {idxs}")
|
| 74 |
+
if not cache:
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
# fallback = last frame
|
| 78 |
+
last_avail = max(cache.keys())
|
| 79 |
+
fallback = cache[last_avail]
|
| 80 |
+
|
| 81 |
+
anchor_frame = None
|
| 82 |
+
gray_seq = []
|
| 83 |
+
for j, idx in enumerate(idxs):
|
| 84 |
+
if idx in cache:
|
| 85 |
+
data = cache[idx]
|
| 86 |
+
fallback = data
|
| 87 |
+
else:
|
| 88 |
+
data = fallback
|
| 89 |
+
|
| 90 |
+
frame, gray = data
|
| 91 |
+
if j == 0: # window [0] is anchor
|
| 92 |
+
anchor_frame = frame
|
| 93 |
+
gray_seq.append(gray)
|
| 94 |
+
|
| 95 |
+
if len(gray_seq) == 1:
|
| 96 |
+
gray_seq.append(gray_seq[0])
|
| 97 |
+
idxs.append(idxs[0])
|
| 98 |
+
return anchor_frame, gray_seq, idxs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
video_path = "/home/xinyi/Project/FD-VQA/test_videos/3369925072.mp4"
|
| 103 |
+
print("video:", video_path)
|
| 104 |
+
|
| 105 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
| 106 |
+
total_frames = len(vr)
|
| 107 |
+
if total_frames <= 1:
|
| 108 |
+
raise RuntimeError(f"Video too short / frame count unavailable: total_frames={total_frames}")
|
| 109 |
+
print("total_frames:", total_frames)
|
| 110 |
+
|
| 111 |
+
size = 224
|
| 112 |
+
win = 8
|
| 113 |
+
win_step = 1
|
| 114 |
+
num_anchors = 16
|
| 115 |
+
|
| 116 |
+
anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
|
| 117 |
+
needed = collect_needed(anchor_idxs, total_frames, win, win_step)
|
| 118 |
+
print("anchor_idxs:", anchor_idxs)
|
| 119 |
+
print("needed:", needed)
|
| 120 |
+
|
| 121 |
+
cache = cache_needed_frames(vr, needed, size)
|
| 122 |
+
print("cached:", len(cache), "needed:", len(needed))
|
| 123 |
+
|
| 124 |
+
frames_list, grays_list, processed_list = [], [], []
|
| 125 |
+
for a in anchor_idxs:
|
| 126 |
+
out = read_window_from_cache(cache, a, total_frames, win, win_step)
|
| 127 |
+
if out is None:
|
| 128 |
+
continue
|
| 129 |
+
anchor_frame, gray_seq, idxs = out
|
| 130 |
+
frames_list.append(anchor_frame)
|
| 131 |
+
grays_list.append(gray_seq)
|
| 132 |
+
processed_list.append(idxs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
print("frames:", len(frames_list))
|
| 136 |
+
print("one frame shape:", frames_list[0].shape if frames_list else None)
|
| 137 |
+
print("each window length:", len(grays_list[0]) if grays_list else 0)
|
| 138 |
+
print("window idxs:", processed_list if processed_list else None)
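The anchor sampling and window clamping in `read_frame_decord.py` can be checked without a video file. The sketch below is a condensed restatement of `sample_frames_uniform` and `window_indices` from the file above. The compact function bodies and the example frame counts (100 and 3 frames) are illustrative assumptions, not part of the commit.

```python
import numpy as np


def sample_frames_uniform(total_frames, num_anchors, win=8, win_step=1):
    """Condensed restatement of the sampler in read_frame_decord.py."""
    if total_frames <= 0:
        return [0] * num_anchors
    min_win = (win - 1) * win_step + 1  # frames spanned by one window
    if total_frames < num_anchors + min_win:
        if total_frames >= num_anchors:
            # spread anchors over the whole clip; windows may run past the end
            return np.linspace(0, total_frames - 1, num_anchors).round().astype(int).tolist()
        # fewer frames than anchors: use every frame, then repeat the last one
        return list(range(total_frames)) + [total_frames - 1] * (num_anchors - total_frames)
    # normal case: every window fits entirely inside the video
    max_start = total_frames - min_win
    return np.linspace(0, max_start, num_anchors).round().astype(int).tolist()


def window_indices(anchor, total_frames, win, win_step):
    """Indices of one window; out-of-range indices clamp to the last frame."""
    last = total_frames - 1
    return [min(anchor + k * win_step, last) for k in range(win)]


# Illustrative inputs (hypothetical frame counts, no video needed)
anchors_long = sample_frames_uniform(100, 4, win=8, win_step=1)
anchors_short = sample_frames_uniform(3, 5, win=8, win_step=1)
clamped = window_indices(96, 100, win=8, win_step=1)
print(anchors_long)
print(anchors_short)
print(clamped)
```

In the normal case the last anchor is placed so that its window ends exactly on the final frame. For short clips the anchor list is padded with the last frame index, and any window index past the end is clamped to the last frame, so every anchor still yields `win` grayscale frames.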
|
src/train.py
ADDED
|
@@ -0,0 +1,604 @@
| 1 |
+
import os
|
| 2 |
+
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import numpy as np
|
| 7 |
+
from scipy.stats import pearsonr as _pearsonr_scipy, spearmanr as _spearmanr_scipy
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import cv2
|
| 11 |
+
from decord import VideoReader, cpu
|
| 12 |
+
from torch.amp import GradScaler, autocast
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from module.compute_weight_map import process_video, compute_weight_map
|
| 16 |
+
from model.qd_model import QD_MODEL
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ----------------------------
|
| 20 |
+
# Data utils
|
| 21 |
+
# ----------------------------
|
| 22 |
+
def read_vid_mos_csv(csv_path):
|
| 23 |
+
rows = []
|
| 24 |
+
with open(csv_path, "r", encoding="utf-8") as f:
|
| 25 |
+
reader = csv.DictReader(f)
|
| 26 |
+
if not reader.fieldnames:
|
| 27 |
+
raise RuntimeError("CSV has no header")
|
| 28 |
+
for r in reader:
|
| 29 |
+
vid = str(r["vid"]).strip()
|
| 30 |
+
mos = float(r["mos"])
|
| 31 |
+
rows.append((vid, mos))
|
| 32 |
+
return rows
|
| 33 |
+
|
| 34 |
+
def split_rows(rows, seed=42, test_ratio=0.2, val_ratio=0.1):
|
| 35 |
+
rng = np.random.default_rng(int(seed))
|
| 36 |
+
idx = np.arange(len(rows))
|
| 37 |
+
rng.shuffle(idx)
|
| 38 |
+
|
| 39 |
+
n = len(rows)
|
| 40 |
+
n_test = int(round(n * test_ratio))
|
| 41 |
+
n_train_all = n - n_test # train+val
|
| 42 |
+
n_val = int(round(n_train_all * val_ratio)) # val from train_all
|
| 43 |
+
|
| 44 |
+
val = [rows[i] for i in idx[:n_val]]
|
| 45 |
+
train = [rows[i] for i in idx[n_val:n_train_all]]
|
| 46 |
+
test = [rows[i] for i in idx[n_train_all:]]
|
| 47 |
+
return train, val, test
|
| 48 |
+
|
| 49 |
+
def split_train_val(rows, seed=42, val_ratio=0.1):
|
| 50 |
+
rng = np.random.default_rng(int(seed))
|
| 51 |
+
idx = np.arange(len(rows))
|
| 52 |
+
rng.shuffle(idx)
|
| 53 |
+
|
| 54 |
+
n = len(rows)
|
| 55 |
+
n_val = int(round(n * val_ratio))
|
| 56 |
+
val = [rows[i] for i in idx[:n_val]]
|
| 57 |
+
train = [rows[i] for i in idx[n_val:]]
|
| 58 |
+
return train, val
|
| 59 |
+
|
| 60 |
+
def pearsonr(x, y, eps=1e-12):
|
| 61 |
+
# PLCC (SciPy): returns a torch scalar tensor so call-site ".item()" still works
|
| 62 |
+
if hasattr(x, "detach"):
|
| 63 |
+
x = x.detach().cpu().numpy()
|
| 64 |
+
if hasattr(y, "detach"):
|
| 65 |
+
y = y.detach().cpu().numpy()
|
| 66 |
+
x = np.asarray(x).reshape(-1)
|
| 67 |
+
y = np.asarray(y).reshape(-1)
|
| 68 |
+
|
| 69 |
+
# avoid NaN when constant / too short
|
| 70 |
+
if x.size < 2 or np.std(x) < eps or np.std(y) < eps:
|
| 71 |
+
return torch.tensor(0.0)
|
| 72 |
+
r, _p = _pearsonr_scipy(x, y)
|
| 73 |
+
if np.isnan(r):
|
| 74 |
+
r = 0.0
|
| 75 |
+
return torch.tensor(float(r))
|
| 76 |
+
|
| 77 |
+
def spearmanr(x, y, eps=1e-12):
|
| 78 |
+
# SRCC (SciPy): handles ties correctly; returns torch scalar tensor
|
| 79 |
+
if hasattr(x, "detach"):
|
| 80 |
+
x = x.detach().cpu().numpy()
|
| 81 |
+
if hasattr(y, "detach"):
|
| 82 |
+
y = y.detach().cpu().numpy()
|
| 83 |
+
|
| 84 |
+
x = np.asarray(x).reshape(-1)
|
| 85 |
+
y = np.asarray(y).reshape(-1)
|
| 86 |
+
if x.size < 2 or np.std(x) < eps or np.std(y) < eps:
|
| 87 |
+
return torch.tensor(0.0)
|
| 88 |
+
|
| 89 |
+
r, _p = _spearmanr_scipy(x, y)
|
| 90 |
+
if np.isnan(r):
|
| 91 |
+
r = 0.0
|
| 92 |
+
return torch.tensor(float(r))
|
| 93 |
+
|
| 94 |
+
# ----------------------------
|
| 95 |
+
# Train utils
|
| 96 |
+
# ----------------------------
|
| 97 |
+
def build_scheduler(optim, args):
|
| 98 |
+
warm = int(args.warmup_epochs)
|
| 99 |
+
total = int(args.epochs)
|
| 100 |
+
warm = max(0, min(warm, total - 1))
|
| 101 |
+
|
| 102 |
+
# warmup
|
| 103 |
+
warmup = torch.optim.lr_scheduler.LinearLR(
|
| 104 |
+
optim,
|
| 105 |
+
start_factor=0.1,
|
| 106 |
+
total_iters=warm if warm > 0 else 1,
|
| 107 |
+
)
|
| 108 |
+
# cosine decay after warmup
|
| 109 |
+
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 110 |
+
optim,
|
| 111 |
+
T_max=(total - warm) if (total - warm) > 0 else 1,
|
| 112 |
+
eta_min=float(args.min_lr),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if warm > 0:
|
| 116 |
+
return torch.optim.lr_scheduler.SequentialLR(
|
| 117 |
+
optim,
|
| 118 |
+
schedulers=[warmup, cosine],
|
| 119 |
+
milestones=[warm],
|
| 120 |
+
)
|
| 121 |
+
return cosine
|
| 122 |
+
|
| 123 |
+
def com_loss(y_pred, y_true, reg_w=0.6, rank_w=1.0, huber_beta=1.0, margin=0.0):
|
| 124 |
+
# 1) Huber / SmoothL1
|
| 125 |
+
if huber_beta is None:
|
| 126 |
+
reg_loss = F.l1_loss(y_pred, y_true, reduction="mean")
|
| 127 |
+
else:
|
| 128 |
+
reg_loss = F.smooth_l1_loss(y_pred, y_true, beta=float(huber_beta), reduction="mean")
|
| 129 |
+
reg_loss = reg_loss * float(reg_w)
|
| 130 |
+
|
| 131 |
+
# 2) pairwise hinge rank
|
| 132 |
+
B = y_true.shape[0]
|
| 133 |
+
if B < 2 or float(rank_w) == 0.0:
|
| 134 |
+
rank_loss = y_pred.new_tensor(0.0)
|
| 135 |
+
return reg_loss + rank_loss, reg_loss, rank_loss
|
| 136 |
+
pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0) # [B,B]
|
| 137 |
+
true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0) # [B,B]
|
| 138 |
+
s = torch.sign(true_diff) # -1,0,+1
|
| 139 |
+
|
| 140 |
+
# ignore pairs whose y_true values are equal (ties)
|
| 141 |
+
mask = (s != 0).float()
|
| 142 |
+
# hinge: max(0, margin - s*(pred_i - pred_j))
|
| 143 |
+
rank_mat = F.relu(float(margin) - s * pred_diff) * mask
|
| 144 |
+
denom = mask.sum().clamp_min(1.0)
|
| 145 |
+
rank_loss = rank_mat.sum() / denom
|
| 146 |
+
rank_loss = rank_loss * float(rank_w)
|
| 147 |
+
|
| 148 |
+
total = reg_loss + rank_loss
|
| 149 |
+
return total, reg_loss, rank_loss
|
| 150 |
+
|
| 151 |
+
# ----------------------------
|
| 152 |
+
# Dataset
|
| 153 |
+
# ----------------------------
|
| 154 |
+
class VQADataset(torch.utils.data.Dataset):
|
| 155 |
+
"""
|
| 156 |
+
Returns per item:
|
| 157 |
+
rgb: [3, T, H, W] float in [0,1] (RGB)
|
| 158 |
+
w_art: [1, T, H, W] float in [0,1]
|
| 159 |
+
w_str: [1, T, H, W] float in [0,1]
|
| 160 |
+
y: scalar float (MOS, optionally normalized)
|
| 161 |
+
vid: str
|
| 162 |
+
"""
|
| 163 |
+
def __init__(self, rows, db_path, clip_len, size, win, win_step, mos_mean=None, mos_std=None):
|
| 164 |
+
self.rows = rows
|
| 165 |
+
self.db_path = str(db_path)
|
| 166 |
+
self.clip_len = int(clip_len)
|
| 167 |
+
self.size = int(size)
|
| 168 |
+
self.win = int(win)
|
| 169 |
+
self.win_step = int(win_step)
|
| 170 |
+
self.mos_mean = mos_mean
|
| 171 |
+
self.mos_std = mos_std
|
| 172 |
+
|
| 173 |
+
def __len__(self):
|
| 174 |
+
return len(self.rows)
|
| 175 |
+
|
| 176 |
+
def __getitem__(self, idx):
|
| 177 |
+
vid, mos = self.rows[int(idx)]
|
| 178 |
+
num_anchors = self.clip_len
|
| 179 |
+
size = self.size
|
| 180 |
+
win = self.win
|
| 181 |
+
win_step = self.win_step
|
| 182 |
+
|
| 183 |
+
# get video path
|
| 184 |
+
base_path = Path(self.db_path) / vid
|
| 185 |
+
video_path = None
|
| 186 |
+
for ext in ("mp4", "avi", "mkv"):
|
| 187 |
+
p = Path(str(base_path) + f".{ext}")
|
| 188 |
+
if p.exists():
|
| 189 |
+
video_path = str(p)
|
| 190 |
+
break
|
| 191 |
+
if video_path is None:
|
| 192 |
+
raise FileNotFoundError(f"Cannot find {vid} video")
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
# read video
|
| 196 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
| 197 |
+
frame_all, w_art_all, w_str_all, anchors_kept = process_video(
|
| 198 |
+
vr,
|
| 199 |
+
size=size, num_anchors=num_anchors, win=win, win_step=win_step,
|
| 200 |
+
)
|
| 201 |
+
frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
|
| 202 |
+
# print("frames_np:", frames_np.shape, frames_np.dtype)
|
| 203 |
+
# print("w_art_np:", w_art_np.shape, w_art_np.dtype)
|
| 204 |
+
# print("w_str_np:", w_str_np.shape, w_str_np.dtype)
|
| 205 |
+
# print("anchors_kept:", len(anchors_kept), "example:", anchors_kept[0])
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print("\n[DATA ERROR]")
|
| 208 |
+
print("idx:", idx)
|
| 209 |
+
print("vid:", vid)
|
| 210 |
+
raise
|
| 211 |
+
finally:
|
| 212 |
+
# release decord video reader
|
| 213 |
+
try:
|
| 214 |
+
if vr is not None:
|
| 215 |
+
del vr
|
| 216 |
+
except Exception:
|
| 217 |
+
pass
|
| 218 |
+
|
| 219 |
+
# fixed length sampling
|
| 220 |
+
T = self.clip_len
|
| 221 |
+
|
| 222 |
+
# frames_sel = [cv2.cvtColor(frames_np[i], cv2.COLOR_BGR2RGB) for i in range(T)] # frame BGR -> RGB: [T,H,W,3] -> [3,T,H,W]
|
| 223 |
+
frames_sel = [frames_np[i] for i in range(T)] # RGB frames_np
|
| 224 |
+
rgb = torch.from_numpy(np.stack(frames_sel, axis=0)).float()
|
| 225 |
+
rgb = rgb.permute(3, 0, 1, 2).contiguous() / 255.0
|
| 226 |
+
|
| 227 |
+
# W_art / W_str: [T,H,W] -> [1,T,H,W]
|
| 228 |
+
w_art = torch.from_numpy(np.stack([w_art_np[i] for i in range(T)], axis=0).astype(np.float32)).unsqueeze(0).float()
|
| 229 |
+
w_str = torch.from_numpy(np.stack([w_str_np[i] for i in range(T)], axis=0).astype(np.float32)).unsqueeze(0).float()
|
| 230 |
+
|
| 231 |
+
# MOS
|
| 232 |
+
y = float(mos)
|
| 233 |
+
if self.mos_mean is not None and self.mos_std is not None:
|
| 234 |
+
y = (y - self.mos_mean) / (self.mos_std + 1e-8)
|
| 235 |
+
y = torch.tensor(y).float()
|
| 236 |
+
return rgb, w_art, w_str, y, str(vid)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ----------------------------
|
| 240 |
+
# Train
|
| 241 |
+
# ----------------------------
|
| 242 |
+
@torch.no_grad()
|
| 243 |
+
def _gather_cat(xs):
|
| 244 |
+
if not xs:
|
| 245 |
+
return torch.empty(0)
|
| 246 |
+
return torch.cat(xs, dim=0)
|
| 247 |
+
|
| 248 |
+
def run_epoch(model, loader, device, *, optim=None, amp=True, mos_mean=None, mos_std=None, desc="", show_pbar=True, log_interval=10):
|
| 249 |
+
is_train = optim is not None
|
| 250 |
+
model.train(is_train)
|
| 251 |
+
|
| 252 |
+
scaler = getattr(run_epoch, "_scaler", None)
|
| 253 |
+
if scaler is None:
|
| 254 |
+
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
|
| 255 |
+
run_epoch._scaler = GradScaler(device_type, enabled=(amp and device_type == "cuda"))
|
| 256 |
+
scaler = run_epoch._scaler
|
| 257 |
+
|
| 258 |
+
losses = []
|
| 259 |
+
y_all = []
|
| 260 |
+
yhat_all = []
|
| 261 |
+
|
| 262 |
+
# ---- tqdm progress bar ----
|
| 263 |
+
it = loader
|
| 264 |
+
if show_pbar:
|
| 265 |
+
it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True)
|
| 266 |
+
|
| 267 |
+
for step, (rgb, w_art, w_str, y, vid) in enumerate(it, start=1):
|
| 268 |
+
rgb = rgb.to(device, non_blocking=True) # [B,3,T,H,W]
|
| 269 |
+
w_art = w_art.to(device, non_blocking=True) # [B,1,T,H,W]
|
| 270 |
+
w_str = w_str.to(device, non_blocking=True) # [B,1,T,H,W]
|
| 271 |
+
y = y.to(device, non_blocking=True).float() # [B]
|
| 272 |
+
|
| 273 |
+
if is_train:
|
| 274 |
+
optim.zero_grad(set_to_none=True)
|
| 275 |
+
|
| 276 |
+
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
|
| 277 |
+
if is_train:
|
| 278 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 279 |
+
yhat, _aux = model(rgb, w_art, w_str) # yhat: [B]
|
| 280 |
+
loss, loss_reg, loss_rank = com_loss(yhat, y)
|
| 281 |
+
|
| 282 |
+
scaler.scale(loss).backward()
|
| 283 |
+
scaler.step(optim)
|
| 284 |
+
scaler.update()
|
| 285 |
+
else:
|
| 286 |
+
with torch.inference_mode():
|
| 287 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 288 |
+
yhat, _aux = model(rgb, w_art, w_str)
|
| 289 |
+
loss, loss_reg, loss_rank = com_loss(yhat, y)
|
| 290 |
+
|
| 291 |
+
loss_cpu = loss.detach().float().cpu()
|
| 292 |
+
losses.append(loss_cpu)
|
| 293 |
+
y_all.append(y.detach().float().cpu())
|
| 294 |
+
yhat_all.append(yhat.detach().float().cpu())
|
| 295 |
+
|
| 296 |
+
# ---- update bar every log_interval steps ----
|
| 297 |
+
if show_pbar and (step % int(log_interval) == 0 or step == len(loader)):
|
| 298 |
+
avg_loss_so_far = torch.stack(losses).mean().item()
|
| 299 |
+
lrs = None
|
| 300 |
+
if is_train and hasattr(optim, "param_groups") and optim.param_groups:
|
| 301 |
+
lrs = [pg.get("lr", None) for pg in optim.param_groups]
|
| 302 |
+
|
| 303 |
+
postfix = {"loss": f"{avg_loss_so_far:.4f}"}
|
| 304 |
+
if lrs is not None:
|
| 305 |
+
postfix["lrs"] = ",".join([f"{x:.2e}" for x in lrs if x is not None])
|
| 306 |
+
it.set_postfix(postfix)
|
| 307 |
+
|
| 308 |
+
y_all = _gather_cat(y_all)
|
| 309 |
+
yhat_all = _gather_cat(yhat_all)
|
| 310 |
+
|
| 311 |
+
# ---- de-normalize: compute correlation coefficients on the original MOS scale ----
|
| 312 |
+
if mos_mean is not None and mos_std is not None:
|
| 313 |
+
y_all = y_all * mos_std + mos_mean
|
| 314 |
+
yhat_all = yhat_all * mos_std + mos_mean
|
| 315 |
+
|
| 316 |
+
plcc = pearsonr(y_all, yhat_all).item() if y_all.numel() > 1 else 0.0
|
| 317 |
+
srcc = spearmanr(y_all, yhat_all).item() if y_all.numel() > 1 else 0.0
|
| 318 |
+
rmse = torch.sqrt(torch.mean((yhat_all - y_all) ** 2)).item() if y_all.numel() > 0 else 0.0
|
| 319 |
+
|
| 320 |
+
avg_loss = torch.stack(losses).mean().item() if losses else 0.0
|
| 321 |
+
return avg_loss, plcc, srcc, rmse
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def main():
|
| 325 |
+
ap = argparse.ArgumentParser()
|
| 326 |
+
# ----- data -----
|
| 327 |
+
ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/LSVQ_TRAIN_metadata.csv")
|
| 328 |
+
ap.add_argument("--db_path", default="/media/xinyi/server/LSVQ/")
|
| 329 |
+
ap.add_argument("--split_seed", type=int, default=42)
|
| 330 |
+
ap.add_argument("--test_ratio", type=float, default=0.2)
|
| 331 |
+
ap.add_argument("--val_ratio", type=float, default=0.1) # fraction of the non-test portion held out for validation
|
| 332 |
+
# ----- video processing -----
|
| 333 |
+
ap.add_argument("--clip_len", type=int, default=16)
|
| 334 |
+
ap.add_argument("--resize", type=int, default=224)
|
| 335 |
+
ap.add_argument("--win", type=int, default=6)
|
| 336 |
+
ap.add_argument("--win_step", type=int, default=1)
|
| 337 |
+
# ----- runtime -----
|
| 338 |
+
ap.add_argument("--batch_size", type=int, default=8)
|
| 339 |
+
ap.add_argument("--num_workers", type=int, default=2)
|
| 340 |
+
ap.add_argument("--device", type=str, default="cuda")
|
| 341 |
+
ap.add_argument("--no_amp", action="store_true")
|
| 342 |
+
# ----- hyperparams -----
|
| 343 |
+
ap.add_argument("--epochs", type=int, default=35)
|
| 344 |
+
ap.add_argument("--warmup_epochs", type=int, default=3)
|
| 345 |
+
ap.add_argument("--lr", type=float, default=1e-5)
|
| 346 |
+
ap.add_argument("--min_lr", type=float, default=1e-6)
|
| 347 |
+
ap.add_argument("--finetune_lr", type=float, default=5e-5)
|
| 348 |
+
ap.add_argument("--weight_decay", type=float, default=1e-2)
|
| 349 |
+
ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
|
| 350 |
+
ap.add_argument("--finetune_last_stage", action="store_true")
|
| 351 |
+
ap.add_argument("--patience", type=int, default=6)
|
| 352 |
+
# ----- save -----
|
| 353 |
+
ap.add_argument("--save_dir", type=str, default="checkpoints")
|
| 354 |
+
ap.add_argument("--save_name", type=str, default="qd_model.pt")
|
| 355 |
+
|
| 356 |
+
args = ap.parse_args()
|
| 357 |
+
torch.manual_seed(args.split_seed)
|
| 358 |
+
device = torch.device(args.device)
|
| 359 |
+
amp = not bool(args.no_amp)
|
| 360 |
+
|
| 361 |
+
# ----------------------------
|
| 362 |
+
# Load rows and split
|
| 363 |
+
# ----------------------------
|
| 364 |
+
csv_path = Path(args.csv_path)
|
| 365 |
+
if csv_path.name == "LSVQ_TRAIN_metadata.csv":
|
| 366 |
+
# LSVQ official split
|
| 367 |
+
test_csv = csv_path.parent / "LSVQ_TEST_metadata.csv"
|
| 368 |
+
if not test_csv.exists():
|
| 369 |
+
raise FileNotFoundError(f"Cannot find LSVQ test csv: {test_csv}")
|
| 370 |
+
train_all = read_vid_mos_csv(str(csv_path))
|
| 371 |
+
test_rows = read_vid_mos_csv(str(test_csv))
|
| 372 |
+
train_rows, val_rows = split_train_val(
|
| 373 |
+
train_all,
|
| 374 |
+
seed=args.split_seed,
|
| 375 |
+
val_ratio=args.val_ratio,
|
| 376 |
+
)
|
| 377 |
+
print(f"[LSVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
|
| 378 |
+
elif csv_path.name == "KVQ_TRAIN_metadata.csv":
|
| 379 |
+
# KVQ challenge split
|
| 380 |
+
val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
|
| 381 |
+
test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
|
| 382 |
+
train_rows = read_vid_mos_csv(str(csv_path))
|
| 383 |
+
val_rows = read_vid_mos_csv(str(val_csv))
|
| 384 |
+
test_rows = read_vid_mos_csv(str(test_csv))
|
| 385 |
+
print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
|
| 386 |
+
else:
|
| 387 |
+
# default split for other datasets
|
| 388 |
+
rows = read_vid_mos_csv(str(csv_path))
|
| 389 |
+
train_rows, val_rows, test_rows = split_rows(
|
| 390 |
+
rows,
|
| 391 |
+
seed=args.split_seed,
|
| 392 |
+
test_ratio=args.test_ratio,
|
| 393 |
+
val_ratio=args.val_ratio,
|
| 394 |
+
)
|
| 395 |
+
# print("sizes:", len(rows), len(train_rows), len(val_rows), len(test_rows))
|
| 396 |
+
# print("train first 3:", train_rows[:3])
|
| 397 |
+
# print("val first 3:", val_rows[:3])
|
| 398 |
+
# print("test first 3:", test_rows[:3])
|
| 399 |
+
|
| 400 |
+
# MOS normalization stats from train split
|
| 401 |
+
mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
|
| 402 |
+
mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
|
| 403 |
+
mos_std = float(mos_train.std()) if len(mos_train) else 1.0
|
| 404 |
+
if mos_std <= 1e-8:
|
| 405 |
+
mos_std = 1.0
|
| 406 |
+
|
| 407 |
+
# ----------------------------
|
| 408 |
+
# DB and datasets
|
| 409 |
+
# ----------------------------
|
| 410 |
+
ds_train = VQADataset(
|
| 411 |
+
train_rows, args.db_path,
|
| 412 |
+
clip_len=args.clip_len,
|
| 413 |
+
size=args.resize,
|
| 414 |
+
win=args.win,
|
| 415 |
+
win_step=args.win_step,
|
| 416 |
+
mos_mean=mos_mean,
|
| 417 |
+
mos_std=mos_std,
|
| 418 |
+
)
|
| 419 |
+
ds_val = VQADataset(
|
| 420 |
+
val_rows, args.db_path,
|
| 421 |
+
clip_len=args.clip_len,
|
| 422 |
+
size=args.resize,
|
| 423 |
+
win=args.win,
|
| 424 |
+
win_step=args.win_step,
|
| 425 |
+
mos_mean=mos_mean,
|
| 426 |
+
mos_std=mos_std,
|
| 427 |
+
)
|
| 428 |
+
ds_test = VQADataset(
|
| 429 |
+
test_rows, args.db_path,
|
| 430 |
+
clip_len=args.clip_len,
|
| 431 |
+
size=args.resize,
|
| 432 |
+
win=args.win,
|
| 433 |
+
win_step=args.win_step,
|
| 434 |
+
mos_mean=mos_mean,
|
| 435 |
+
mos_std=mos_std,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
pin = str(device).startswith("cuda")
|
| 439 |
+
|
| 440 |
+
loader_train = torch.utils.data.DataLoader(
|
| 441 |
+
ds_train,
|
| 442 |
+
batch_size=args.batch_size,
|
| 443 |
+
shuffle=True,
|
| 444 |
+
num_workers=args.num_workers,
|
| 445 |
+
pin_memory=pin,
|
| 446 |
+
# persistent_workers=(args.num_workers > 0),
|
| 447 |
+
prefetch_factor=4 if args.num_workers > 0 else None,
|
| 448 |
+
drop_last=False,
|
| 449 |
+
)
|
| 450 |
+
loader_val = torch.utils.data.DataLoader(
|
| 451 |
+
ds_val,
|
| 452 |
+
batch_size=args.batch_size,
|
| 453 |
+
shuffle=False,
|
| 454 |
+
num_workers=args.num_workers,
|
| 455 |
+
pin_memory=pin,
|
| 456 |
+
drop_last=False,
|
| 457 |
+
)
|
| 458 |
+
loader_test = torch.utils.data.DataLoader(
|
| 459 |
+
ds_test,
|
| 460 |
+
batch_size=args.batch_size,
|
| 461 |
+
shuffle=False,
|
| 462 |
+
num_workers=args.num_workers,
|
| 463 |
+
pin_memory=pin,
|
| 464 |
+
drop_last=False,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# ----------------------------
|
| 468 |
+
# Model
|
| 469 |
+
# ----------------------------
|
| 470 |
+
model = QD_MODEL(
|
| 471 |
+
clip_model="openai/clip-vit-base-patch16",
|
| 472 |
+
).to(device)
|
| 473 |
+
|
| 474 |
+
# Stage A: freeze CLIP
|
| 475 |
+
model.freeze_clip_all()
|
| 476 |
+
|
| 477 |
+
clip_params_all = []
|
| 478 |
+
other_params_all = []
|
| 479 |
+
for name, p in model.named_parameters():
|
| 480 |
+
if name.startswith("encoder."):
|
| 481 |
+
clip_params_all.append(p)
|
| 482 |
+
else:
|
| 483 |
+
other_params_all.append(p)
|
| 484 |
+
|
| 485 |
+
param_groups = []
|
| 486 |
+
if other_params_all:
|
| 487 |
+
param_groups.append({"params": other_params_all, "lr": float(args.lr)})
|
| 488 |
+
if clip_params_all:
|
| 489 |
+
param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})
|
| 490 |
+
|
| 491 |
+
optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
|
| 492 |
+
scheduler = build_scheduler(optim, args)
|
| 493 |
+
# ----------------------------
|
| 494 |
+
# Train loop
|
| 495 |
+
# ----------------------------
|
| 496 |
+
save_dir = Path(args.save_dir)
|
| 497 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 498 |
+
save_path = save_dir / args.save_name
|
| 499 |
+
best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
|
| 500 |
+
best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")
|
| 501 |
+
|
| 502 |
+
best_val_srcc = -1e18
|
| 503 |
+
bad_epochs = 0
|
| 504 |
+
did_unfreeze = False
|
| 505 |
+
|
| 506 |
+
for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Epochs", dynamic_ncols=True):
|
| 507 |
+
|
| 508 |
+
# Stage B: optional CLIP finetune after warmup
|
| 509 |
+
if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
|
| 510 |
+
model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
|
| 511 |
+
did_unfreeze = True
|
| 512 |
+
|
| 513 |
+
# Do not rebuild the optimizer: keep the Adam state and only update the learning rates
|
| 514 |
+
if hasattr(optim, "param_groups") and len(optim.param_groups) >= 2:
|
| 515 |
+
optim.param_groups[0]["lr"] = float(args.lr)
|
| 516 |
+
optim.param_groups[1]["lr"] = float(args.finetune_lr)
|
| 517 |
+
|
| 518 |
+
print(
|
| 519 |
+
f"[Stage B] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks | "
|
| 520 |
+
f"lr={float(args.lr)} finetune_lr={float(args.finetune_lr)}"
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
|
| 524 |
+
model, loader_train, device,
|
| 525 |
+
optim=optim,
|
| 526 |
+
amp=amp,
|
| 527 |
+
mos_mean=mos_mean,
|
| 528 |
+
mos_std=mos_std,
|
| 529 |
+
desc=f"Train e{epoch:03d}",
|
| 530 |
+
show_pbar=True,
|
| 531 |
+
log_interval=10,
|
| 532 |
+
)
|
| 533 |
+
va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
|
| 534 |
+
model, loader_val, device,
|
| 535 |
+
optim=None,
|
| 536 |
+
amp=amp,
|
| 537 |
+
mos_mean=mos_mean,
|
| 538 |
+
mos_std=mos_std,
|
| 539 |
+
desc=f"Val e{epoch:03d}",
|
| 540 |
+
show_pbar=True,
|
| 541 |
+
log_interval=10,
|
| 542 |
+
)
|
| 543 |
+
scheduler.step()
|
| 544 |
+
print(
|
| 545 |
+
f"epoch {epoch:03d} | "
|
| 546 |
+
f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
|
| 547 |
+
f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
# Save "last" checkpoint every epoch
|
| 551 |
+
ckpt = {
|
| 552 |
+
"epoch": epoch,
|
| 553 |
+
"model": model.state_dict(),
|
| 554 |
+
"optim": optim.state_dict(),
|
| 555 |
+
"mos_mean": mos_mean,
|
| 556 |
+
"mos_std": mos_std,
|
| 557 |
+
"args": vars(args),
|
| 558 |
+
"best_val_srcc": best_val_srcc,
|
| 559 |
+
}
|
| 560 |
+
torch.save(ckpt, str(save_path))
|
| 561 |
+
|
| 562 |
+
# Save best by val SRCC (higher is better)
|
| 563 |
+
if va_srcc > best_val_srcc:
|
| 564 |
+
best_val_srcc = va_srcc
|
| 565 |
+
bad_epochs = 0
|
| 566 |
+
ckpt["best_val_srcc"] = best_val_srcc
|
| 567 |
+
torch.save(ckpt, str(best_path))
|
| 568 |
+
torch.save(model.state_dict(), str(best_weights_path))
|
| 569 |
+
print(
|
| 570 |
+
f" [best] val_srcc={best_val_srcc:.4f} (val_rmse={va_rmse:.4f}) -> saved "
|
| 571 |
+
f"{best_path} and {best_weights_path}"
|
| 572 |
+
)
|
| 573 |
+
else:
|
| 574 |
+
bad_epochs += 1
|
| 575 |
+
if bad_epochs >= int(args.patience):
|
| 576 |
+
print(
|
| 577 |
+
f"[EarlyStop] val_srcc did not improve for {bad_epochs} epochs. "
|
| 578 |
+
f"Stop at epoch {epoch}."
|
| 579 |
+
)
|
| 580 |
+
break
|
| 581 |
+
# ----------------------------
|
| 582 |
+
# Test (load best)
|
| 583 |
+
# ----------------------------
|
| 584 |
+
if best_weights_path.exists():
|
| 585 |
+
sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
|
| 586 |
+
model.load_state_dict(sd, strict=True)
|
| 587 |
+
print(f"Loaded best weights: {best_weights_path}")
|
| 588 |
+
elif best_path.exists():
|
| 589 |
+
best = torch.load(str(best_path), map_location=device)
|
| 590 |
+
model.load_state_dict(best["model"], strict=True)
|
| 591 |
+
print(f"Loaded best checkpoint: {best_path} (val_srcc={best.get('best_val_srcc', None)})")
|
| 592 |
+
|
| 593 |
+
te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
|
| 594 |
+
model, loader_test, device,
|
| 595 |
+
optim=None,
|
| 596 |
+
amp=amp,
|
| 597 |
+
mos_mean=mos_mean,
|
| 598 |
+
mos_std=mos_std,
|
| 599 |
+
)
|
| 600 |
+
print(f"TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
if __name__ == "__main__":
|
| 604 |
+
main()
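The best-checkpoint and early-stopping bookkeeping in the training loop above can be isolated as a small pure-Python sketch. The function name `early_stop_trace` and its return tuple are hypothetical and not part of the repository; the sketch only mirrors the `best_val_srcc` / `bad_epochs` / `patience` logic and leaves out the model, the scheduler, and all checkpoint I/O.

```python
def early_stop_trace(val_srcc_per_epoch, patience):
    """Replay per-epoch validation SRCC values through the early-stopping rule.

    Hypothetical helper for illustration only. Returns a tuple
    (best_epoch, best_val_srcc, last_epoch_run), with epochs counted from 1.
    """
    best_val_srcc = float("-inf")
    best_epoch = None
    bad_epochs = 0
    last_epoch = 0

    for epoch, va_srcc in enumerate(val_srcc_per_epoch, start=1):
        last_epoch = epoch
        if va_srcc > best_val_srcc:
            # Strict improvement: this is where the "best" checkpoint would be saved.
            best_val_srcc = va_srcc
            best_epoch = epoch
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                # No improvement for `patience` consecutive epochs: stop training.
                break

    return best_epoch, best_val_srcc, last_epoch


if __name__ == "__main__":
    history = [0.50, 0.60, 0.55, 0.58, 0.59, 0.70]
    print(early_stop_trace(history, patience=3))
```

Because the comparison is strict (`>`), an epoch that only ties the best SRCC counts as a bad epoch, which matches the loop above. A small `patience` can therefore stop training before a later improvement is reached, as with the final `0.70` in the sample history.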
|
src/transfer.py
ADDED
|
@@ -0,0 +1,291 @@
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from train import (
|
| 7 |
+
read_vid_mos_csv,
|
| 8 |
+
split_rows,
|
| 9 |
+
VQADataset,
|
| 10 |
+
run_epoch,
|
| 11 |
+
build_scheduler,
|
| 12 |
+
)
|
| 13 |
+
from model.qd_model import QD_MODEL
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_pretrained_weights(model, pretrained_path, device):
|
| 17 |
+
p = Path(pretrained_path)
|
| 18 |
+
obj = torch.load(str(p), map_location=device, weights_only=True)
|
| 19 |
+
# *.pt: dict checkpoint
|
| 20 |
+
if isinstance(obj, dict) and "model" in obj:
|
| 21 |
+
model.load_state_dict(obj["model"], strict=True)
|
| 22 |
+
return obj
|
| 23 |
+
# best_weights.pt: state_dict
|
| 24 |
+
else:
|
| 25 |
+
model.load_state_dict(obj, strict=True)
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
def make_loaders(rows_train, rows_val, rows_test, args, mos_mean, mos_std, device):
|
| 29 |
+
ds_train = VQADataset(
|
| 30 |
+
rows_train, args.db_path,
|
| 31 |
+
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
|
| 32 |
+
mos_mean=mos_mean, mos_std=mos_std,
|
| 33 |
+
)
|
| 34 |
+
ds_val = VQADataset(
|
| 35 |
+
rows_val, args.db_path,
|
| 36 |
+
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
|
| 37 |
+
mos_mean=mos_mean, mos_std=mos_std,
|
| 38 |
+
)
|
| 39 |
+
ds_test = VQADataset(
|
| 40 |
+
rows_test, args.db_path,
|
| 41 |
+
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
|
| 42 |
+
mos_mean=mos_mean, mos_std=mos_std,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
pin = str(device).startswith("cuda")
|
| 46 |
+
loader_train = torch.utils.data.DataLoader(
|
| 47 |
+
ds_train, batch_size=args.batch_size, shuffle=True,
|
| 48 |
+
num_workers=args.num_workers, pin_memory=pin,
|
| 49 |
+
persistent_workers=(args.num_workers > 0),
|
| 50 |
+
prefetch_factor=4 if args.num_workers > 0 else None,
|
| 51 |
+
drop_last=False,
|
| 52 |
+
)
|
| 53 |
+
loader_val = torch.utils.data.DataLoader(
|
| 54 |
+
ds_val, batch_size=args.batch_size, shuffle=False,
|
| 55 |
+
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
|
| 56 |
+
)
|
| 57 |
+
loader_test = torch.utils.data.DataLoader(
|
| 58 |
+
ds_test, batch_size=args.batch_size, shuffle=False,
|
| 59 |
+
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
|
| 60 |
+
)
|
| 61 |
+
return loader_train, loader_val, loader_test
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def main():
|
| 65 |
+
ap = argparse.ArgumentParser()
|
| 66 |
+
# ----- mode -----
|
| 67 |
+
ap.add_argument("--mode", choices=["finetune", "test_only"], required=True)
|
| 68 |
+
ap.add_argument("--pretrained", default="/home/xinyi/Project/FD-VQA/src/checkpoints/kvq/qd_model.best.pt", help="path to the pretrained model checkpoint")
|
| 69 |
+
# ----- data -----
|
| 70 |
+
ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/SHORTS-SDR-DATASET_metadata.csv")
|
| 71 |
+
ap.add_argument("--db_path", default="/media/xinyi/server/video_dataset/shorts-hdr-dataset/sdr/")
|
| 72 |
+
ap.add_argument("--split_seed", type=int, default=0)
|
| 73 |
+
ap.add_argument("--test_ratio", type=float, default=0.2)
|
| 74 |
+
ap.add_argument("--val_ratio", type=float, default=0.1) # train 80%
|
| 75 |
+
# ----- video processing -----
|
| 76 |
+
ap.add_argument("--clip_len", type=int, default=16)
|
| 77 |
+
ap.add_argument("--resize", type=int, default=224)
|
| 78 |
+
ap.add_argument("--win", type=int, default=6)
|
| 79 |
+
ap.add_argument("--win_step", type=int, default=1)
|
| 80 |
+
# ----- runtime -----
|
| 81 |
+
ap.add_argument("--batch_size", type=int, default=8)
|
| 82 |
+
ap.add_argument("--num_workers", type=int, default=4)
|
| 83 |
+
ap.add_argument("--device", type=str, default="cuda")
|
| 84 |
+
ap.add_argument("--no_amp", action="store_true")
|
| 85 |
+
# ----- finetune hyperparams -----
|
| 86 |
+
ap.add_argument("--epochs", type=int, default=10)
|
| 87 |
+
ap.add_argument("--warmup_epochs", type=int, default=1)
|
| 88 |
+
ap.add_argument("--lr", type=float, default=1e-5)
|
| 89 |
+
ap.add_argument("--min_lr", type=float, default=1e-6)
|
| 90 |
+
ap.add_argument("--finetune_lr", type=float, default=5e-5)
|
| 91 |
+
ap.add_argument("--weight_decay", type=float, default=1e-2)
|
| 92 |
+
ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
|
| 93 |
+
ap.add_argument("--finetune_last_stage", action="store_true")
|
| 94 |
+
ap.add_argument("--patience", type=int, default=6)
|
| 95 |
+
# ----- save -----
|
| 96 |
+
ap.add_argument("--save_dir", type=str, default="checkpoints_transfer")
|
| 97 |
+
ap.add_argument("--save_name", type=str, default="transfer.pt")
|
| 98 |
+
|
| 99 |
+
# ----- MOS normalization control -----
|
| 100 |
+
# finetune: use target train mean/std
|
| 101 |
+
ap.add_argument("--test_only_norm", choices=["none", "use_source_ckpt"], default="use_source_ckpt")
|
| 102 |
+
|
| 103 |
+
args = ap.parse_args()
|
| 104 |
+
torch.manual_seed(args.split_seed)
|
| 105 |
+
device = torch.device(args.device)
|
| 106 |
+
amp = not bool(args.no_amp)
|
| 107 |
+
|
| 108 |
+
# ----------------------------
|
| 109 |
+
# Load target rows and split
|
| 110 |
+
# ----------------------------
|
| 111 |
+
rows = read_vid_mos_csv(args.csv_path)
|
| 112 |
+
if args.mode == "finetune":
|
| 113 |
+
csv_path = Path(args.csv_path)
|
| 114 |
+
if csv_path.name == "KVQ_TRAIN_metadata.csv":
|
| 115 |
+
# KVQ challenge split
|
| 116 |
+
val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
|
| 117 |
+
test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
|
| 118 |
+
train_rows = read_vid_mos_csv(str(csv_path))
|
| 119 |
+
val_rows = read_vid_mos_csv(str(val_csv))
|
| 120 |
+
test_rows = read_vid_mos_csv(str(test_csv))
|
| 121 |
+
print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
|
| 122 |
+
else:
|
| 123 |
+
train_rows, val_rows, test_rows = split_rows(
|
| 124 |
+
rows, seed=args.split_seed, test_ratio=args.test_ratio, val_ratio=args.val_ratio
|
| 125 |
+
)
|
| 126 |
+
mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
|
| 127 |
+
mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
|
| 128 |
+
mos_std = float(mos_train.std()) if len(mos_train) else 1.0
|
| 129 |
+
if mos_std <= 1e-8:
|
| 130 |
+
mos_std = 1.0
|
| 131 |
+
else:
|
| 132 |
+
# test_only
|
| 133 |
+
train_rows, val_rows, test_rows = [], [], rows
|
| 134 |
+
mos_mean, mos_std = None, None
|
| 135 |
+
|
| 136 |
+
# ----------------------------
|
| 137 |
+
# Build model, load pretrained
|
| 138 |
+
# ----------------------------
|
| 139 |
+
model = QD_MODEL(
|
| 140 |
+
clip_model="openai/clip-vit-base-patch16",
|
| 141 |
+
).to(device)
|
| 142 |
+
|
| 143 |
+
ckpt = load_pretrained_weights(model, args.pretrained, device)
|
| 144 |
+
print(f"Loaded pretrained: {args.pretrained}")
|
| 145 |
+
|
| 146 |
+
# ----------------------------
|
| 147 |
+
# Decide normalization for test_only
|
| 148 |
+
# ----------------------------
|
| 149 |
+
if args.mode == "test_only":
|
| 150 |
+
if args.test_only_norm == "use_source_ckpt" and isinstance(ckpt, dict):
|
| 151 |
+
# Note: this choice does not affect SRCC/PLCC
|
| 152 |
+
mos_mean = ckpt.get("mos_mean", None)
|
| 153 |
+
mos_std = ckpt.get("mos_std", None)
|
| 154 |
+
if mos_mean is None or mos_std is None:
|
| 155 |
+
mos_mean, mos_std = None, None
|
| 156 |
+
print("[warn] pretrained ckpt has no mos_mean/std, fallback to no normalization.")
|
| 157 |
+
else:
|
| 158 |
+
print(f"test_only uses source mos_mean/std from ckpt: mean={mos_mean:.4f}, std={mos_std:.4f}")
|
| 159 |
+
else:
|
| 160 |
+
print("test_only uses no MOS normalization.")
|
| 161 |
+
|
| 162 |
+
# ----------------------------
|
| 163 |
+
# Mode: test_only
|
| 164 |
+
# ----------------------------
|
| 165 |
+
if args.mode == "test_only":
|
| 166 |
+
ds_test = VQADataset(
|
| 167 |
+
test_rows, args.db_path,
|
| 168 |
+
clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
|
| 169 |
+
mos_mean=mos_mean, mos_std=mos_std,
|
| 170 |
+
)
|
| 171 |
+
pin = str(device).startswith("cuda")
|
| 172 |
+
loader_test = torch.utils.data.DataLoader(
|
| 173 |
+
ds_test, batch_size=args.batch_size, shuffle=False,
|
| 174 |
+
num_workers=args.num_workers, pin_memory=pin, drop_last=False,
|
| 175 |
+
)
|
| 176 |
+
print("num_test_rows =", len(test_rows))
|
| 177 |
+
print("len(ds_test) =", len(ds_test))
|
| 178 |
+
print("len(loader_test) =", len(loader_test), "batch_size =", args.batch_size)
|
| 179 |
+
|
| 180 |
+
te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
|
| 181 |
+
model, loader_test, device,
|
| 182 |
+
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
|
| 183 |
+
desc="TestOnly", show_pbar=True
|
| 184 |
+
)
|
| 185 |
+
print(f"TEST_ONLY | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
|
| 186 |
+
return
|
| 187 |
+
|
| 188 |
+
# ----------------------------
|
| 189 |
+
# DataLoaders
|
| 190 |
+
# ----------------------------
|
| 191 |
+
loader_train, loader_val, loader_test = make_loaders(
|
| 192 |
+
train_rows, val_rows, test_rows, args,
|
| 193 |
+
mos_mean=mos_mean, mos_std=mos_std, device=device
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# ----------------------------
|
| 197 |
+
# Mode: finetune
|
| 198 |
+
# ----------------------------
|
| 199 |
+
model.freeze_clip_all()
|
| 200 |
+
did_unfreeze = False
|
| 201 |
+
|
| 202 |
+
clip_params_all = []
|
| 203 |
+
other_params_all = []
|
| 204 |
+
for name, p in model.named_parameters():
|
| 205 |
+
if name.startswith("encoder."):
|
| 206 |
+
clip_params_all.append(p)
|
| 207 |
+
else:
|
| 208 |
+
other_params_all.append(p)
|
| 209 |
+
|
| 210 |
+
param_groups = []
|
| 211 |
+
if other_params_all:
|
| 212 |
+
param_groups.append({"params": other_params_all, "lr": float(args.lr)})
|
| 213 |
+
if clip_params_all:
|
| 214 |
+
param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})
|
| 215 |
+
|
| 216 |
+
optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
|
| 217 |
+
scheduler = build_scheduler(optim, args)
|
| 218 |
+
|
| 219 |
+
save_dir = Path(args.save_dir)
|
| 220 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 221 |
+
last_path = save_dir / args.save_name
|
| 222 |
+
best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
|
| 223 |
+
best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")
|
| 224 |
+
|
| 225 |
+
best_val_srcc = -1e18
|
| 226 |
+
bad_epochs = 0
|
| 227 |
+
|
| 228 |
+
for epoch in range(1, int(args.epochs) + 1):
|
| 229 |
+
if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
|
| 230 |
+
model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
|
| 231 |
+
did_unfreeze = True
|
| 232 |
+
print(f"[Finetune] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks")
|
| 233 |
+
|
| 234 |
+
tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
|
| 235 |
+
model, loader_train, device,
|
| 236 |
+
optim=optim, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
|
| 237 |
+
desc=f"FT Train e{epoch:03d}", show_pbar=True
|
| 238 |
+
)
|
| 239 |
+
va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
|
| 240 |
+
model, loader_val, device,
|
| 241 |
+
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
|
| 242 |
+
desc=f"FT Val e{epoch:03d}", show_pbar=True
|
| 243 |
+
)
|
| 244 |
+
scheduler.step()
|
| 245 |
+
|
| 246 |
+
print(
|
| 247 |
+
f"epoch {epoch:03d} | "
|
| 248 |
+
f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
|
| 249 |
+
f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
ckpt_out = {
|
| 253 |
+
"epoch": epoch,
|
| 254 |
+
"model": model.state_dict(),
|
| 255 |
+
"optim": optim.state_dict(),
|
| 256 |
+
"mos_mean": mos_mean,
|
| 257 |
+
"mos_std": mos_std,
|
| 258 |
+
"args": vars(args),
|
| 259 |
+
"best_val_srcc": best_val_srcc,
|
| 260 |
+
}
|
| 261 |
+
torch.save(ckpt_out, str(last_path))
|
| 262 |
+
|
| 263 |
+
if va_srcc > best_val_srcc:
|
| 264 |
+
best_val_srcc = va_srcc
|
| 265 |
+
bad_epochs = 0
|
| 266 |
+
ckpt_out["best_val_srcc"] = best_val_srcc
|
| 267 |
+
torch.save(ckpt_out, str(best_path))
|
| 268 |
+
torch.save(model.state_dict(), str(best_weights_path))
|
| 269 |
+
print(f" [best] val_srcc={best_val_srcc:.4f} -> saved {best_weights_path}")
|
| 270 |
+
else:
|
| 271 |
+
bad_epochs += 1
|
| 272 |
+
if bad_epochs >= int(args.patience):
|
| 273 |
+
print(f"[EarlyStop] val_srcc did not improve for {bad_epochs} epochs. Stop.")
|
| 274 |
+
break
|
| 275 |
+
|
| 276 |
+
# load best and test
|
| 277 |
+
if best_weights_path.exists():
|
| 278 |
+
sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
|
| 279 |
+
model.load_state_dict(sd, strict=True)
|
| 280 |
+
print(f"Loaded best weights: {best_weights_path}")
|
| 281 |
+
|
| 282 |
+
te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
|
| 283 |
+
model, loader_test, device,
|
| 284 |
+
optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
|
| 285 |
+
desc="FT Test", show_pbar=True
|
| 286 |
+
)
|
| 287 |
+
print(f"FINETUNE TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
if __name__ == "__main__":
|
| 291 |
+
main()
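The MOS normalization used in finetune mode (statistics taken from the target train split, with fallbacks for an empty split or a near-zero spread) can be sketched without NumPy or PyTorch. The helper names `mos_norm_stats`, `normalize`, and `denormalize` are hypothetical. The forward formula `(mos - mean) / std` is an assumption about what `VQADataset` does internally, since that class is not shown here; the inverse `z * std + mean` matches how predictions are mapped back in `transfer_test_only.py`.

```python
from statistics import fmean, pstdev


def mos_norm_stats(train_mos, eps=1e-8):
    """Return (mean, std) of the training MOS values with safe fallbacks.

    Hypothetical helper: mirrors the mos_mean / mos_std computation in main().
    """
    if not train_mos:
        # Empty split: identity normalization.
        return 0.0, 1.0
    mean = fmean(train_mos)
    # Population standard deviation, the same convention as numpy's default ddof=0.
    std = pstdev(train_mos)
    if std <= eps:
        # Constant labels: avoid dividing by (almost) zero.
        std = 1.0
    return mean, std


def normalize(mos, mean, std):
    # Assumed forward transform applied to labels before training.
    return (mos - mean) / std


def denormalize(z, mean, std):
    # Inverse transform used to bring predictions back to the MOS scale.
    return z * std + mean


if __name__ == "__main__":
    mean, std = mos_norm_stats([1.0, 2.0, 3.0])
    z = normalize(3.0, mean, std)
    print(mean, std, z, denormalize(z, mean, std))
```

The script computes these statistics in `float32`, so values can differ from this double-precision sketch in the last few digits. Since the transform is affine with a positive scale, it changes loss and RMSE values but leaves SRCC and PLCC unchanged.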
|
src/transfer_test_only.py
ADDED
|
@@ -0,0 +1,284 @@
|
| 1 |
+
import os
|
| 2 |
+
os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import torch
|
| 7 |
+
from torch.amp import autocast
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from train import VQADataset, com_loss, pearsonr, read_vid_mos_csv, spearmanr
|
| 11 |
+
from model.qd_model import QD_MODEL
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_checkpoint(ckpt_path, device):
|
| 15 |
+
ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=True)
|
| 16 |
+
if isinstance(ckpt, dict) and "model" in ckpt:
|
| 17 |
+
return {
|
| 18 |
+
"state_dict": ckpt["model"],
|
| 19 |
+
"train_mos_mean": ckpt.get("mos_mean"),
|
| 20 |
+
"train_mos_std": ckpt.get("mos_std"),
|
| 21 |
+
"train_args": ckpt.get("args", {}),
|
| 22 |
+
"is_full_checkpoint": True,
|
| 23 |
+
}
|
| 24 |
+
if isinstance(ckpt, dict):
|
| 25 |
+
return {
|
| 26 |
+
"state_dict": ckpt,
|
| 27 |
+
"train_mos_mean": None,
|
| 28 |
+
"train_mos_std": None,
|
| 29 |
+
"train_args": {},
|
| 30 |
+
"is_full_checkpoint": False,
|
| 31 |
+
}
|
| 32 |
+
raise TypeError(f"Unsupported checkpoint type: {type(ckpt)!r}")
|
| 33 |
+
|
| 34 |
+
def infer_test_scale(rows):
|
| 35 |
+
mos_values = [float(mos) for _vid, mos in rows]
|
| 36 |
+
if not mos_values:
|
| 37 |
+
raise ValueError("Cannot infer test scale from empty rows")
|
| 38 |
+
|
| 39 |
+
lo = min(mos_values)
|
| 40 |
+
hi = max(mos_values)
|
| 41 |
+
|
| 42 |
+
if 0.0 <= lo and hi <= 1.0:
|
| 43 |
+
return 0.0, 1.0
|
| 44 |
+
if 1.0 <= lo and hi <= 5.0:
|
| 45 |
+
return 1.0, 5.0
|
| 46 |
+
if 0.0 <= lo and hi <= 5.0:
|
| 47 |
+
return 0.0, 5.0
|
| 48 |
+
return 0.0, 100.0
|
| 49 |
+
|
| 50 |
+
def linear_remap(x, src_min, src_max, dst_min, dst_max):
|
| 51 |
+
src_min = float(src_min)
|
| 52 |
+
src_max = float(src_max)
|
| 53 |
+
dst_min = float(dst_min)
|
| 54 |
+
dst_max = float(dst_max)
|
| 55 |
+
|
| 56 |
+
if abs(src_max - src_min) <= 1e-12:
|
| 57 |
+
raise ValueError("Source scale range must be non-zero")
|
| 58 |
+
|
| 59 |
+
return (x - src_min) / (src_max - src_min) * (dst_max - dst_min) + dst_min
|
| 60 |
+
|
| 61 |
+
def save_predictions_csv(save_path, vids, y_true_raw, pred_train_scale, pred_eval_scale):
|
| 62 |
+
save_path = Path(save_path)
|
| 63 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 64 |
+
|
| 65 |
+
with open(save_path, "w", newline="", encoding="utf-8") as f:
|
| 66 |
+
writer = csv.writer(f)
|
| 67 |
+
writer.writerow(["vid", "y_true_raw", "pred_train_scale", "pred_eval_scale"])
|
| 68 |
+
for vid, y_true, pred_train, pred_eval in zip(
|
| 69 |
+
vids,
|
| 70 |
+
y_true_raw.tolist(),
|
| 71 |
+
pred_train_scale.tolist(),
|
| 72 |
+
pred_eval_scale.tolist(),
|
| 73 |
+
strict=False,
|
| 74 |
+
):
|
| 75 |
+
writer.writerow([vid, float(y_true), float(pred_train), float(pred_eval)])
|
| 76 |
+
|
| 77 |
+
return save_path
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
|
| 80 |
+
def evaluate_and_collect(
|
| 81 |
+
model,
|
| 82 |
+
loader,
|
| 83 |
+
device,
|
| 84 |
+
*,
|
| 85 |
+
amp=True,
|
| 86 |
+
train_mos_mean,
|
| 87 |
+
train_mos_std,
|
| 88 |
+
train_scale_min,
|
| 89 |
+
train_scale_max,
|
| 90 |
+
test_scale_min,
|
| 91 |
+
test_scale_max,
|
| 92 |
+
desc="",
|
| 93 |
+
show_pbar=True,
|
| 94 |
+
log_interval=10,
|
| 95 |
+
):
|
| 96 |
+
model.eval()
|
| 97 |
+
|
| 98 |
+
losses = []
|
| 99 |
+
y_all = []
|
| 100 |
+
yhat_all = []
|
| 101 |
+
vids_all = []
|
| 102 |
+
|
| 103 |
+
it = loader
|
| 104 |
+
if show_pbar:
|
| 105 |
+
it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True)
|
| 106 |
+
|
| 107 |
+
for step, (rgb, w_art, w_str, y, vid) in enumerate(it, start=1):
|
| 108 |
+
rgb = rgb.to(device, non_blocking=True)
|
| 109 |
+
w_art = w_art.to(device, non_blocking=True)
|
| 110 |
+
w_str = w_str.to(device, non_blocking=True)
|
| 111 |
+
y = y.to(device, non_blocking=True).float()
|
| 112 |
+
|
| 113 |
+
device_type = "cuda" if str(device).startswith("cuda") else "cpu"
|
| 114 |
+
with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
|
| 115 |
+
yhat, _aux = model(rgb, w_art, w_str)
|
| 116 |
+
loss, _loss_reg, _loss_rank = com_loss(yhat, y)
|
| 117 |
+
|
| 118 |
+
losses.append(loss.detach().float().cpu())
|
| 119 |
+
y_all.append(y.detach().float().cpu())
|
| 120 |
+
yhat_all.append(yhat.detach().float().cpu())
|
| 121 |
+
vids_all.extend(list(vid))
|
| 122 |
+
|
| 123 |
+
if show_pbar and (step % int(log_interval) == 0 or step == len(loader)):
|
| 124 |
+
avg_loss_so_far = torch.stack(losses).mean().item()
|
| 125 |
+
it.set_postfix({"loss": f"{avg_loss_so_far:.4f}"})
|
| 126 |
+
|
| 127 |
+
if y_all:
|
| 128 |
+
y_all = torch.cat(y_all, dim=0)
|
| 129 |
+
yhat_all = torch.cat(yhat_all, dim=0)
|
| 130 |
+
else:
|
| 131 |
+
y_all = torch.empty(0)
|
| 132 |
+
yhat_all = torch.empty(0)
|
| 133 |
+
|
| 134 |
+
y_true_raw = y_all * float(train_mos_std) + float(train_mos_mean)
|
| 135 |
+
pred_train_scale = yhat_all * float(train_mos_std) + float(train_mos_mean)
|
| 136 |
+
pred_eval_scale = linear_remap(
|
| 137 |
+
pred_train_scale,
|
| 138 |
+
src_min=float(train_scale_min),
|
| 139 |
+
src_max=float(train_scale_max),
|
| 140 |
+
dst_min=float(test_scale_min),
|
| 141 |
+
dst_max=float(test_scale_max),
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
plcc = pearsonr(y_true_raw, pred_eval_scale).item() if y_true_raw.numel() > 1 else 0.0
|
| 145 |
+
srcc = spearmanr(y_true_raw, pred_eval_scale).item() if y_true_raw.numel() > 1 else 0.0
|
| 146 |
+
rmse = (
|
| 147 |
+
torch.sqrt(torch.mean((pred_eval_scale - y_true_raw) ** 2)).item()
|
| 148 |
+
if y_true_raw.numel() > 0
|
| 149 |
+
else 0.0
|
| 150 |
+
)
|
| 151 |
+
avg_loss = torch.stack(losses).mean().item() if losses else 0.0
|
| 152 |
+
|
| 153 |
+
return {
|
| 154 |
+
"loss": avg_loss,
|
| 155 |
+
"plcc": plcc,
|
| 156 |
+
"srcc": srcc,
|
| 157 |
+
"rmse": rmse,
|
| 158 |
+
"vids": vids_all,
|
| 159 |
+
"y_true_raw": y_true_raw,
|
| 160 |
+
"pred_train_scale": pred_train_scale,
|
| 161 |
+
"pred_eval_scale": pred_eval_scale,
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
def main():
|
| 165 |
+
ap = argparse.ArgumentParser()
|
| 166 |
+
ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/lsvq/qd_model.best.pt")
|
| 167 |
+
ap.add_argument("--csv_path", type=str, default="/home/xinyi/Project/FD-VQA/metadata/KVQ_metadata.csv")
|
| 168 |
+
ap.add_argument("--db_path", type=str, default="/media/xinyi/server/video_dataset/KVQ")
|
| 169 |
+
|
| 170 |
+
ap.add_argument("--clip_len", type=int, default=16)
|
| 171 |
+
ap.add_argument("--resize", type=int, default=224)
|
| 172 |
+
ap.add_argument("--win", type=int, default=6)
|
| 173 |
+
ap.add_argument("--win_step", type=int, default=1)
|
| 174 |
+
|
| 175 |
+
ap.add_argument("--batch_size", type=int, default=8)
|
| 176 |
+
ap.add_argument("--num_workers", type=int, default=2)
|
| 177 |
+
ap.add_argument("--device", type=str, default="cuda")
|
| 178 |
+
ap.add_argument("--no_amp", action="store_true")
|
| 179 |
+
|
| 180 |
+
ap.add_argument("--train_scale_min", type=float, default=0.0)
|
| 181 |
+
ap.add_argument("--train_scale_max", type=float, default=100.0)
|
| 182 |
+
ap.add_argument("--test_scale_min", type=float, default=1.0)
|
| 183 |
+
ap.add_argument("--test_scale_max", type=float, default=5.0)
|
| 184 |
+
|
| 185 |
+
ap.add_argument("--save_pred_csv", type=str, default="/home/xinyi/Project/FD-VQA/src/transfer_test/transfer_test_only_konvid_1k.csv")
|
| 186 |
+
args = ap.parse_args()
|
| 187 |
+
|
| 188 |
+
device = torch.device(args.device)
|
| 189 |
+
amp = not bool(args.no_amp)
|
| 190 |
+
ckpt_info = load_checkpoint(Path(args.ckpt_path), device)
|
| 191 |
+
|
| 192 |
+
train_mos_mean = ckpt_info["train_mos_mean"]
|
| 193 |
+
train_mos_std = ckpt_info["train_mos_std"]
|
| 194 |
+
if train_mos_mean is None or train_mos_std is None:
|
| 195 |
+
raise ValueError(
|
| 196 |
+
"Checkpoint has no mos_mean/mos_std. Load a full checkpoint (*.best.pt / *.pt) instead of *.best_weights.pt."
|
| 197 |
+
)
|
| 198 |
+
if float(train_mos_std) <= 1e-8:
|
| 199 |
+
raise ValueError("train_mos_std must be > 0")
|
| 200 |
+
|
| 201 |
+
rows = read_vid_mos_csv(args.csv_path)
|
| 202 |
+
if not rows:
|
| 203 |
+
raise ValueError(f"No rows found in csv: {args.csv_path}")
|
| 204 |
+
|
| 205 |
+
if args.test_scale_min is None or args.test_scale_max is None:
|
| 206 |
+
inferred_test_scale_min, inferred_test_scale_max = infer_test_scale(rows)
|
| 207 |
+
test_scale_min = inferred_test_scale_min
|
| 208 |
+
test_scale_max = inferred_test_scale_max
|
| 209 |
+
else:
|
| 210 |
+
test_scale_min = float(args.test_scale_min)
|
| 211 |
+
test_scale_max = float(args.test_scale_max)
|
| 212 |
+
|
    dataset = VQADataset(
        rows,
        args.db_path,
        clip_len=args.clip_len,
        size=args.resize,
        win=args.win,
        win_step=args.win_step,
        mos_mean=float(train_mos_mean),
        mos_std=float(train_mos_std),
    )

    pin = str(device).startswith("cuda")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=int(args.batch_size),
        shuffle=False,
        num_workers=int(args.num_workers),
        pin_memory=pin,
        drop_last=False,
        prefetch_factor=4 if int(args.num_workers) > 0 else None,
    )

    model = QD_MODEL(
        clip_model="openai/clip-vit-base-patch16",
    ).to(device)
    model.load_state_dict(ckpt_info["state_dict"], strict=True)

    print(f"Loaded checkpoint: {args.ckpt_path}")
    print(f"Training normalization: mean={float(train_mos_mean):.6f}, std={float(train_mos_std):.6f}")
    print(
        f"Scale mapping: train=[{float(args.train_scale_min):.3f}, {float(args.train_scale_max):.3f}] -> "
        f"test=[{float(test_scale_min):.3f}, {float(test_scale_max):.3f}]"
    )
    print(f"Test rows: {len(rows)}")

    metrics = evaluate_and_collect(
        model,
        loader,
        device,
        amp=amp,
        train_mos_mean=float(train_mos_mean),
        train_mos_std=float(train_mos_std),
        train_scale_min=float(args.train_scale_min),
        train_scale_max=float(args.train_scale_max),
        test_scale_min=float(test_scale_min),
        test_scale_max=float(test_scale_max),
        desc="Cross-dataset test",
        show_pbar=True,
        log_interval=10,
    )

    print(
        "TEST | "
        f"loss={metrics['loss']:.4f} "
        f"plcc={metrics['plcc']:.4f} "
        f"srcc={metrics['srcc']:.4f} "
        f"rmse={metrics['rmse']:.4f}"
    )

    if args.save_pred_csv:
        save_path = save_predictions_csv(
            args.save_pred_csv,
            metrics["vids"],
            metrics["y_true_raw"],
            metrics["pred_train_scale"],
            metrics["pred_eval_scale"],
        )
        print(f"Saved predictions to: {save_path}")


if __name__ == "__main__":
    main()
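
The script above passes both the training normalization (`train_mos_mean`, `train_mos_std`) and the two score ranges to `evaluate_and_collect`, whose body is not shown in this section. As a rough illustration of what such a "scale mapping" could look like, the sketch below assumes the model predicts z-scored MOS values and that the train-to-test mapping is a plain linear min-max rescale. The helper name `map_to_test_scale` is hypothetical and is not part of the repository; the actual mapping used by the script may differ.

```python
import numpy as np


def map_to_test_scale(pred_norm, train_mean, train_std,
                      train_min=0.0, train_max=100.0,
                      test_min=1.0, test_max=5.0):
    """Hypothetical helper: undo z-score normalization, then rescale linearly.

    Assumes predictions were trained on (mos - train_mean) / train_std and that
    the two MOS ranges are related by a simple min-max mapping.
    """
    pred_norm = np.asarray(pred_norm, dtype=np.float64)
    # Back to the training dataset's raw MOS scale (e.g. 0..100).
    pred_train = pred_norm * train_std + train_mean
    # Position within the training range, as a fraction (0..1 when in range).
    frac = (pred_train - train_min) / (train_max - train_min)
    # Same relative position within the test dataset's range (e.g. 1..5).
    pred_test = test_min + frac * (test_max - test_min)
    return pred_train, pred_test


pred_train, pred_test = map_to_test_scale(
    [-1.0, 0.0, 1.0], train_mean=50.0, train_std=25.0
)
print(pred_train)  # values on the assumed 0..100 training scale
print(pred_test)   # the same predictions expressed on the 1..5 test scale
```

Because this mapping is linear with a positive slope, correlation metrics such as PLCC and SRCC are unchanged by it; only scale-dependent metrics such as RMSE depend on which range the predictions are expressed in.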