Xinyi Wang commited on
Commit
db25ead
·
0 Parent(s):

project files

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .gitignore
2
+ .DS_Store
3
+ test_videos/
README.md ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FGSVQA
2
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=xinyiW915/FGSVQA)
3
+ ![GitHub Repo stars](https://img.shields.io/github/stars/xinyiW915/FGSVQA?logo=github)
4
+ ![Python](https://img.shields.io/badge/Python-3.8+-blue)
5
+ [![arXiv](https://img.shields.io/badge/arXiv-2605.20016-b31b1b.svg)](http://arxiv.org/abs/2605.20016)
6
+
7
+ Official Code for the following paper:
8
+
9
+ **X. Wang, A. Katsenou, J.Shen and D. Bull**. [FGSVQA: Frequency-Guided Short-form Video Quality Assessment](http://arxiv.org/abs/2605.20016)
10
+
11
+ [Our paper]() was accepted by the 18th International Conference on Quality of Multimedia Experience ([QoMEX 2026](https://qomex2026.itec.aau.at/)).
12
+
13
+ ---
14
+
15
+ ## Performance
16
+ We validated our proposed method on two publicly available Short-form UGC datasets: KVQ and YouTube SFV+HDR dataset (YT-SFV).
17
+
18
+ #### **Spearman’s Rank Correlation Coefficient (SRCC)**
19
+ | **Model** | **KVQ** | **YT-SFV (SDR)** | **YT-SFV (HDR2SDR)** |
20
+ |----------------------------|-----------|------------------|----------------------|
21
+ | FGSVQA | 0.877 | 0.788 | 0.543 |
22
+
23
+ #### **Pearson’s Linear Correlation Coefficient (PLCC)**
24
+ | **Model** | **KVQ** | **YT-SFV (SDR)** | **YT-SFV (HDR2SDR)** |
25
+ |----------------------------|-----------|------------------|----------------------|
26
+ | FGSVQA | 0.878 | 0.818 | 0.666 |
27
+
28
+ #### **GPU runtime comparison (averaged over 10 runs) across different spatial resolutions on "SDR\_Animal\_5ngj.mp4".**
29
+ | Method | Time(s)<br>540P | Time(s)<br>720P | Time(s)<br>1080P | Time(s)<br>2160P | Ground truth: 4.308<br>Predicted Score|
30
+ |---|------------:|------------:|-------------:|---:|---:|
31
+ | Fast-VQA | 0.599 | 0.673 | 0.909 | 2.217 | 3.319 |
32
+ | FasterVQA | 0.489 | 0.547 | **0.696** | **1.343** | 3.556 |
33
+ | DOVER | 0.920 | 1.022 | 1.293 | 2.783 | 3.814 |
34
+ | FGSVQA | **0.313** | **0.405** | 0.697 | 2.137 | **3.878** |
35
+
36
+ More results can be found in **[correlation_result.ipynb](https://github.com/xinyiW915/FGSVQA/blob/main/src/correlation_result.ipynb)**.
37
+
38
+ ## Proposed Model
39
+ Overview of the proposed model with the two branches: the frequency-guided weight map and the CLIP vision encoder.
40
+
41
+ <img src="./SVQA.png" alt="proposed_FGSVQA_framework" width="800"/>
42
+
43
+ ## Usage
44
+ ### 📌 Install Requirement
45
+ The repository is built with **Python 3.10** and can be installed via the following commands:
46
+
47
+ ```shell
48
+ git clone https://github.com/xinyiW915/FGSVQA.git
49
+ cd FGSVQA
50
+ conda create -n fgsvqa python=3.10 -y
51
+ conda activate fgsvqa
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ ### 📥 Download UGC Datasets
56
+ The corresponding UGC video datasets can be downloaded from the following sources:
57
+ [KVQ](https://lixinustc.github.io/projects/KVQ/), [YouTube SFV+HDR](https://media.withyoutube.com/sfv-hdr).
58
+
59
+ The metadata for the experimented UGC dataset is available under [`./metadata`](./metadata).
60
+
61
+ ### 🎬 Test Demo
62
+ Run the pre-trained model to evaluate the perceptual quality of a single video. The demo script reports the predicted quality score, runtime, and model complexity.
63
+
64
+ The model checkpoint should be provided through `--ckpt_path`. Please use a full checkpoint file, such as `qd_model.best.pt`, which contains the saved model weights together with the training MOS mean and standard deviation.
65
+
66
+ To evaluate a single video, run:
67
+ ```shell
68
+ python demo_test.py \
69
+ --ckpt_path <MODEL_PATH> \
70
+ --db_path <VIDEO_FOLDER> \
71
+ --video_id <VIDEO_ID> \
72
+ --device <DEVICE>
73
+ ````
74
+ For example:
75
+ ```shell
76
+ python demo_test.py \
77
+ --ckpt_path ./checkpoints/lsvq/qd_model.best.pt \
78
+ --db_path ./test_videos/ \
79
+ --video_id SDR_Animal_5ngj \
80
+ --device cuda
81
+ ```
82
+
83
+ ### 🔁 Cross-Dataset Evaluation
84
+ To evaluate a trained model on another dataset, use `transfer_test_only.py`. This script loads a trained checkpoint, reports the evaluation metrics, and saves the prediction results to a CSV file.
85
+
86
+ Run:
87
+ ```shell
88
+ python transfer_test_only.py \
89
+ --ckpt_path <MODEL_PATH> \
90
+ --csv_path <TEST_METADATA_CSV> \
91
+ --db_path <TEST_VIDEO_FOLDER> \
92
+ --device <DEVICE> \
93
+ --save_pred_csv <SAVE_PREDICTION_CSV>
94
+ ```
95
+ For example:
96
+ ```shell
97
+ python transfer_test_only.py \
98
+ --ckpt_path ./checkpoints/lsvq/qd_model.best.pt \
99
+ --csv_path ./metadata/KVQ_metadata.csv \
100
+ --db_path /path/to/KVQ/videos \
101
+ --device cuda \
102
+ --save_pred_csv /path/to/transfer_test_only_konvid_1k.csv
103
+ ```
104
+
105
+ ## Training
106
+ Steps to train and fine-tune the model on different datasets.
107
+
108
+ ### Train Model
109
+ Train the model using the metadata CSV file and the corresponding video folder. The metadata CSV file should contain `vid` and `mos` columns.
110
+
111
+ ```shell
112
+ python train.py \
113
+ --csv_path <TRAIN_METADATA_CSV> \
114
+ --db_path <VIDEO_FOLDER> \
115
+ --save_dir <SAVE_DIR> \
116
+ --save_name qd_model.pt \
117
+ --device <DEVICE> \
118
+ --finetune_last_stage
119
+ ```
120
+ For example:
121
+ ```shell
122
+ python train.py \
123
+ --csv_path ./metadata/KVQ_TRAIN_metadata.csv \
124
+ --db_path /path/to/KVQ/videos \
125
+ --save_dir ./checkpoints/kvq \
126
+ --save_name qd_model.pt \
127
+ --device cuda \
128
+ --finetune_last_stage
129
+ ```
130
+ The script saves the latest checkpoint and the best-performing checkpoint according to the validation SRCC.
131
+
132
+ ### Transfer Model
133
+ To fine-tune a pre-trained model on a new dataset, run:
134
+
135
+ ```shell
136
+ python transfer.py \
137
+ --mode finetune \
138
+ --pretrained <PRETRAINED_MODEL_PATH> \
139
+ --csv_path <TARGET_METADATA_CSV> \
140
+ --db_path <TARGET_VIDEO_FOLDER> \
141
+ --save_dir <SAVE_DIR> \
142
+ --save_name transfer.pt \
143
+ --device <DEVICE> \
144
+ --finetune_last_stage
145
+ ```
146
+ For example:
147
+ ```shell
148
+ python transfer.py \
149
+ --mode finetune \
150
+ --pretrained ./checkpoints/shorts-hdr-dataset_sdr/qd_model.best.pt \
151
+ --csv_path ./metadata/KVQ_TRAIN_metadata.csv \
152
+ --db_path /path/to/KVQ/videos \
153
+ --save_dir ./checkpoints_transfer/kvq \
154
+ --save_name transfer.pt \
155
+ --device cuda \
156
+ --finetune_last_stage
157
+ ```
158
+
159
+ ### Test Only
160
+ To directly test a pre-trained model on another dataset, run:
161
+
162
+ ```shell
163
+ python transfer.py \
164
+ --mode test_only \
165
+ --pretrained <PRETRAINED_MODEL_PATH> \
166
+ --csv_path <TEST_METADATA_CSV> \
167
+ --db_path <TEST_VIDEO_FOLDER> \
168
+ --device <DEVICE>
169
+ ```
170
+ For example:
171
+ ```shell
172
+ python transfer.py \
173
+ --mode test_only \
174
+ --pretrained ./checkpoints/shorts-hdr-dataset_sdr/qd_model.best.pt \
175
+ --csv_path ./metadata/KVQ_metadata.csv \
176
+ --db_path /path/to/KVQ/videos \
177
+ --device cuda
178
+ ```
179
+
180
+
181
+ ## Acknowledgment
182
+ This work was funded by the UKRI MyWorld Strength in Places Programme (SIPF00006/1) as part of my PhD study.
183
+
184
+ ## Citation
185
+ If you find this paper and the repo useful, please cite our paper 😊:
186
+
187
+ ```bibtex
188
+ @article{wang2026fgsvqa,
189
+ title={FGSVQA: Frequency-Guided Short-form Video Quality Assessment},
190
+ author={Wang, Xinyi and Katsenou, Angeliki, Shen, Junxiao and Bull, David},
191
+ booktitle={2026 18th International Conference on Quality of Multimedia Experience (QoMEX)},
192
+ year={2026},
193
+ organization={IEEE}
194
+ }
195
+ ```
196
+
197
+ ## Contact:
198
+ Xinyi WANG, ```xinyi.wang@bristol.ac.uk```
metadata/KVQ_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
metadata/LSVQ_TEST_1080P_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
metadata/LSVQ_TEST_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
metadata/LSVQ_TRAIN_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
metadata/SHORTS-HDR2SDR-DATASET_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
metadata/SHORTS-SDR-DATASET_metadata.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.12.0
2
+ annotated-doc==0.0.4
3
+ anyio==4.12.1
4
+ av==16.1.0
5
+ bitsandbytes==0.49.1
6
+ certifi==2026.1.4
7
+ click==8.3.1
8
+ contourpy==1.3.2
9
+ cuda-bindings==12.9.4
10
+ cuda-pathfinder==1.3.3
11
+ cycler==0.12.1
12
+ decord==0.6.0
13
+ einops==0.8.2
14
+ exceptiongroup==1.3.1
15
+ filelock==3.20.0
16
+ fonttools==4.61.1
17
+ fsspec==2025.12.0
18
+ h11==0.16.0
19
+ hf-xet==1.2.0
20
+ httpcore==1.0.9
21
+ httpx==0.28.1
22
+ huggingface_hub==1.4.1
23
+ idna==3.11
24
+ Jinja2==3.1.6
25
+ kiwisolver==1.4.9
26
+ mamba-ssm==2.3.0
27
+ markdown-it-py==4.0.0
28
+ MarkupSafe==2.1.5
29
+ matplotlib==3.10.8
30
+ mdurl==0.1.2
31
+ mpmath==1.3.0
32
+ networkx==3.4.2
33
+ ninja==1.13.0
34
+ numpy==2.2.6
35
+ nvidia-cublas-cu12==12.1.3.1
36
+ nvidia-cuda-cupti-cu12==12.1.105
37
+ nvidia-cuda-nvrtc-cu12==12.1.105
38
+ nvidia-cuda-runtime-cu12==12.1.105
39
+ nvidia-cudnn-cu12==9.1.0.70
40
+ nvidia-cufft-cu12==11.0.2.54
41
+ nvidia-cufile-cu12==1.13.1.3
42
+ nvidia-curand-cu12==10.3.2.106
43
+ nvidia-cusolver-cu12==11.4.5.107
44
+ nvidia-cusparse-cu12==12.1.0.106
45
+ nvidia-cusparselt-cu12==0.7.1
46
+ nvidia-nccl-cu12==2.21.5
47
+ nvidia-nvjitlink-cu12==12.8.93
48
+ nvidia-nvshmem-cu12==3.4.5
49
+ nvidia-nvtx-cu12==12.1.105
50
+ opencv-python==4.13.0.92
51
+ packaging @ file:///home/task_176104877067765/conda-bld/packaging_1761049113113/work
52
+ pillow==12.1.1
53
+ psutil==7.2.2
54
+ Pygments==2.19.2
55
+ pyparsing==3.3.2
56
+ python-dateutil==2.9.0.post0
57
+ PyYAML==6.0.3
58
+ regex==2026.1.15
59
+ rich==14.3.2
60
+ safetensors==0.7.0
61
+ scipy==1.15.3
62
+ shellingham==1.5.4
63
+ six==1.17.0
64
+ sympy==1.13.1
65
+ tokenizers==0.22.2
66
+ torch==2.5.1+cu121
67
+ torchaudio==2.5.1+cu121
68
+ torchvision==0.20.1+cu121
69
+ tqdm==4.67.3
70
+ transformers==5.1.0
71
+ triton==3.1.0
72
+ typer==0.23.0
73
+ typer-slim==0.23.0
74
+ typing_extensions==4.15.0
75
+ xformers==0.0.34
src/correlation_result.ipynb ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "ExecuteTime": {
9
+ "end_time": "2026-05-19T12:42:02.280702Z",
10
+ "start_time": "2026-05-19T12:42:02.276769Z"
11
+ }
12
+ },
13
+ "source": [
14
+ "from pathlib import Path\n",
15
+ "import re\n",
16
+ "import pandas as pd"
17
+ ],
18
+ "outputs": [],
19
+ "execution_count": 19
20
+ },
21
+ {
22
+ "metadata": {},
23
+ "cell_type": "markdown",
24
+ "source": "Intra-dataset evaluation: Train on and test on the same <dataset>",
25
+ "id": "145ddc0eb7f59e45"
26
+ },
27
+ {
28
+ "metadata": {
29
+ "ExecuteTime": {
30
+ "end_time": "2026-05-19T12:42:02.291070Z",
31
+ "start_time": "2026-05-19T12:42:02.288978Z"
32
+ }
33
+ },
34
+ "cell_type": "code",
35
+ "source": "LOG_ROOT = Path(\"./checkpoints\")",
36
+ "id": "eba6249c8ad34da8",
37
+ "outputs": [],
38
+ "execution_count": 20
39
+ },
40
+ {
41
+ "metadata": {
42
+ "ExecuteTime": {
43
+ "end_time": "2026-05-19T12:42:02.315625Z",
44
+ "start_time": "2026-05-19T12:42:02.302579Z"
45
+ }
46
+ },
47
+ "cell_type": "code",
48
+ "source": [
49
+ "filename_pattern = re.compile(r\"^train_(?P<dataset>.+)\\.log$\")\n",
50
+ "result_pattern = re.compile(\n",
51
+ " r\"TEST\\s*\\|\\s*\"\n",
52
+ " r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
53
+ " r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
54
+ " r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
55
+ " r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
56
+ ")\n",
57
+ "records = []\n",
58
+ "for log_path in LOG_ROOT.rglob(\"*.log\"):\n",
59
+ " filename_match = filename_pattern.match(log_path.name)\n",
60
+ " dataset = filename_match.group(\"dataset\")\n",
61
+ "\n",
62
+ " text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
63
+ " matches = list(result_pattern.finditer(text))\n",
64
+ " if not matches:\n",
65
+ " continue\n",
66
+ " records.append({\n",
67
+ " \"dataset\": dataset,\n",
68
+ " \"srcc\": float(matches[-1].group(\"srcc\")),\n",
69
+ " \"plcc\": float(matches[-1].group(\"plcc\")),\n",
70
+ " })\n",
71
+ "train_df = (pd.DataFrame(records).sort_values(\"dataset\", ignore_index=True)[[\"dataset\", \"srcc\", \"plcc\"]])\n",
72
+ "train_df"
73
+ ],
74
+ "id": "4f2532e62a94f758",
75
+ "outputs": [
76
+ {
77
+ "data": {
78
+ "text/plain": [
79
+ " dataset srcc plcc\n",
80
+ "0 finevd 0.8596 0.8579\n",
81
+ "1 konvid_1k 0.7970 0.8171\n",
82
+ "2 kvq 0.8768 0.8781\n",
83
+ "3 live_vqc 0.6386 0.7192\n",
84
+ "4 live_yt_gaming 0.8554 0.8755\n",
85
+ "5 lsvq 0.8753 0.8760\n",
86
+ "6 shorts-hdr-dataset_hdr2sdr 0.5431 0.6657\n",
87
+ "7 shorts-hdr-dataset_sdr 0.7884 0.8182\n",
88
+ "8 youtube_ugc 0.8325 0.8548"
89
+ ],
90
+ "text/html": [
91
+ "<div>\n",
92
+ "<style scoped>\n",
93
+ " .dataframe tbody tr th:only-of-type {\n",
94
+ " vertical-align: middle;\n",
95
+ " }\n",
96
+ "\n",
97
+ " .dataframe tbody tr th {\n",
98
+ " vertical-align: top;\n",
99
+ " }\n",
100
+ "\n",
101
+ " .dataframe thead th {\n",
102
+ " text-align: right;\n",
103
+ " }\n",
104
+ "</style>\n",
105
+ "<table border=\"1\" class=\"dataframe\">\n",
106
+ " <thead>\n",
107
+ " <tr style=\"text-align: right;\">\n",
108
+ " <th></th>\n",
109
+ " <th>dataset</th>\n",
110
+ " <th>srcc</th>\n",
111
+ " <th>plcc</th>\n",
112
+ " </tr>\n",
113
+ " </thead>\n",
114
+ " <tbody>\n",
115
+ " <tr>\n",
116
+ " <th>0</th>\n",
117
+ " <td>finevd</td>\n",
118
+ " <td>0.8596</td>\n",
119
+ " <td>0.8579</td>\n",
120
+ " </tr>\n",
121
+ " <tr>\n",
122
+ " <th>1</th>\n",
123
+ " <td>konvid_1k</td>\n",
124
+ " <td>0.7970</td>\n",
125
+ " <td>0.8171</td>\n",
126
+ " </tr>\n",
127
+ " <tr>\n",
128
+ " <th>2</th>\n",
129
+ " <td>kvq</td>\n",
130
+ " <td>0.8768</td>\n",
131
+ " <td>0.8781</td>\n",
132
+ " </tr>\n",
133
+ " <tr>\n",
134
+ " <th>3</th>\n",
135
+ " <td>live_vqc</td>\n",
136
+ " <td>0.6386</td>\n",
137
+ " <td>0.7192</td>\n",
138
+ " </tr>\n",
139
+ " <tr>\n",
140
+ " <th>4</th>\n",
141
+ " <td>live_yt_gaming</td>\n",
142
+ " <td>0.8554</td>\n",
143
+ " <td>0.8755</td>\n",
144
+ " </tr>\n",
145
+ " <tr>\n",
146
+ " <th>5</th>\n",
147
+ " <td>lsvq</td>\n",
148
+ " <td>0.8753</td>\n",
149
+ " <td>0.8760</td>\n",
150
+ " </tr>\n",
151
+ " <tr>\n",
152
+ " <th>6</th>\n",
153
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
154
+ " <td>0.5431</td>\n",
155
+ " <td>0.6657</td>\n",
156
+ " </tr>\n",
157
+ " <tr>\n",
158
+ " <th>7</th>\n",
159
+ " <td>shorts-hdr-dataset_sdr</td>\n",
160
+ " <td>0.7884</td>\n",
161
+ " <td>0.8182</td>\n",
162
+ " </tr>\n",
163
+ " <tr>\n",
164
+ " <th>8</th>\n",
165
+ " <td>youtube_ugc</td>\n",
166
+ " <td>0.8325</td>\n",
167
+ " <td>0.8548</td>\n",
168
+ " </tr>\n",
169
+ " </tbody>\n",
170
+ "</table>\n",
171
+ "</div>"
172
+ ]
173
+ },
174
+ "execution_count": 21,
175
+ "metadata": {},
176
+ "output_type": "execute_result"
177
+ }
178
+ ],
179
+ "execution_count": 21
180
+ },
181
+ {
182
+ "metadata": {},
183
+ "cell_type": "markdown",
184
+ "source": "Cross-dataset evaluation: Train on <trainOn> dataset and test on \"testOn\" dataset",
185
+ "id": "68dbaf20d93a49ed"
186
+ },
187
+ {
188
+ "metadata": {
189
+ "ExecuteTime": {
190
+ "end_time": "2026-05-19T12:42:02.318434Z",
191
+ "start_time": "2026-05-19T12:42:02.316713Z"
192
+ }
193
+ },
194
+ "cell_type": "code",
195
+ "source": "LOG_ROOT = Path(\"./checkpoints_transfer\")",
196
+ "id": "e8583a08cc496a7c",
197
+ "outputs": [],
198
+ "execution_count": 22
199
+ },
200
+ {
201
+ "metadata": {
202
+ "ExecuteTime": {
203
+ "end_time": "2026-05-19T12:42:02.328167Z",
204
+ "start_time": "2026-05-19T12:42:02.319595Z"
205
+ }
206
+ },
207
+ "cell_type": "code",
208
+ "source": [
209
+ "filename_pattern = re.compile(r\"^transfer_(?P<trainOn>.+?)_test_on_(?P<testOn>.+?)\\.log$\")\n",
210
+ "result_pattern = re.compile(\n",
211
+ " r\"TEST_ONLY\\s*\\|\\s*\"\n",
212
+ " r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
213
+ " r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
214
+ " r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
215
+ " r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
216
+ ")\n",
217
+ "records = []\n",
218
+ "for log_path in LOG_ROOT.rglob(\"transfer_*_test_on_*.log\"):\n",
219
+ " filename_match = filename_pattern.match(log_path.name)\n",
220
+ " trainOn = filename_match.group(\"trainOn\")\n",
221
+ " testOn = filename_match.group(\"testOn\")\n",
222
+ "\n",
223
+ " text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
224
+ " matches = list(result_pattern.finditer(text))\n",
225
+ " if not matches:\n",
226
+ " continue\n",
227
+ " records.append({\n",
228
+ " \"trainOn\": trainOn,\n",
229
+ " \"testOn\": testOn,\n",
230
+ " \"srcc\": float(matches[-1].group(\"srcc\")),\n",
231
+ " \"plcc\": float(matches[-1].group(\"plcc\")),\n",
232
+ " })\n",
233
+ "transfer_df = (pd.DataFrame(records).sort_values([\"trainOn\", \"testOn\"], ignore_index=True))\n",
234
+ "transfer_df"
235
+ ],
236
+ "id": "b196c863f89e0e0d",
237
+ "outputs": [
238
+ {
239
+ "data": {
240
+ "text/plain": [
241
+ " trainOn testOn srcc plcc\n",
242
+ "0 kvq shorts-hdr-dataset_hdr2sdr 0.4761 0.5885\n",
243
+ "1 kvq shorts-hdr-dataset_sdr 0.7548 0.8065\n",
244
+ "2 shorts-hdr-dataset_hdr2sdr kvq 0.5976 0.5693\n",
245
+ "3 shorts-hdr-dataset_hdr2sdr shorts-hdr-dataset_sdr 0.6166 0.6802\n",
246
+ "4 shorts-hdr-dataset_sdr kvq 0.7342 0.7451\n",
247
+ "5 shorts-hdr-dataset_sdr shorts-hdr-dataset_hdr2sdr 0.5447 0.6958"
248
+ ],
249
+ "text/html": [
250
+ "<div>\n",
251
+ "<style scoped>\n",
252
+ " .dataframe tbody tr th:only-of-type {\n",
253
+ " vertical-align: middle;\n",
254
+ " }\n",
255
+ "\n",
256
+ " .dataframe tbody tr th {\n",
257
+ " vertical-align: top;\n",
258
+ " }\n",
259
+ "\n",
260
+ " .dataframe thead th {\n",
261
+ " text-align: right;\n",
262
+ " }\n",
263
+ "</style>\n",
264
+ "<table border=\"1\" class=\"dataframe\">\n",
265
+ " <thead>\n",
266
+ " <tr style=\"text-align: right;\">\n",
267
+ " <th></th>\n",
268
+ " <th>trainOn</th>\n",
269
+ " <th>testOn</th>\n",
270
+ " <th>srcc</th>\n",
271
+ " <th>plcc</th>\n",
272
+ " </tr>\n",
273
+ " </thead>\n",
274
+ " <tbody>\n",
275
+ " <tr>\n",
276
+ " <th>0</th>\n",
277
+ " <td>kvq</td>\n",
278
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
279
+ " <td>0.4761</td>\n",
280
+ " <td>0.5885</td>\n",
281
+ " </tr>\n",
282
+ " <tr>\n",
283
+ " <th>1</th>\n",
284
+ " <td>kvq</td>\n",
285
+ " <td>shorts-hdr-dataset_sdr</td>\n",
286
+ " <td>0.7548</td>\n",
287
+ " <td>0.8065</td>\n",
288
+ " </tr>\n",
289
+ " <tr>\n",
290
+ " <th>2</th>\n",
291
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
292
+ " <td>kvq</td>\n",
293
+ " <td>0.5976</td>\n",
294
+ " <td>0.5693</td>\n",
295
+ " </tr>\n",
296
+ " <tr>\n",
297
+ " <th>3</th>\n",
298
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
299
+ " <td>shorts-hdr-dataset_sdr</td>\n",
300
+ " <td>0.6166</td>\n",
301
+ " <td>0.6802</td>\n",
302
+ " </tr>\n",
303
+ " <tr>\n",
304
+ " <th>4</th>\n",
305
+ " <td>shorts-hdr-dataset_sdr</td>\n",
306
+ " <td>kvq</td>\n",
307
+ " <td>0.7342</td>\n",
308
+ " <td>0.7451</td>\n",
309
+ " </tr>\n",
310
+ " <tr>\n",
311
+ " <th>5</th>\n",
312
+ " <td>shorts-hdr-dataset_sdr</td>\n",
313
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
314
+ " <td>0.5447</td>\n",
315
+ " <td>0.6958</td>\n",
316
+ " </tr>\n",
317
+ " </tbody>\n",
318
+ "</table>\n",
319
+ "</div>"
320
+ ]
321
+ },
322
+ "execution_count": 23,
323
+ "metadata": {},
324
+ "output_type": "execute_result"
325
+ }
326
+ ],
327
+ "execution_count": 23
328
+ },
329
+ {
330
+ "metadata": {},
331
+ "cell_type": "markdown",
332
+ "source": "Cross-dataset evaluation: Train on \"trainOn\" dataset and finetune on \"fintuneOn\" dataset",
333
+ "id": "e08409c56d2eaea3"
334
+ },
335
+ {
336
+ "metadata": {
337
+ "ExecuteTime": {
338
+ "end_time": "2026-05-19T12:42:02.337383Z",
339
+ "start_time": "2026-05-19T12:42:02.329308Z"
340
+ }
341
+ },
342
+ "cell_type": "code",
343
+ "source": [
344
+ "filename_pattern = re.compile(r\"^transfer_(?P<trainOn>.+?)_finetune_on_(?P<finetuneOn>.+?)\\.log$\")\n",
345
+ "finetune_result_pattern = re.compile(\n",
346
+ " r\"FINETUNE TEST\\s*\\|\\s*\"\n",
347
+ " r\"loss=(?P<loss>[-+]?\\d*\\.?\\d+)\\s+\"\n",
348
+ " r\"plcc=(?P<plcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
349
+ " r\"srcc=(?P<srcc>[-+]?\\d*\\.?\\d+)\\s+\"\n",
350
+ " r\"rmse=(?P<rmse>[-+]?\\d*\\.?\\d+)\"\n",
351
+ ")\n",
352
+ "records = []\n",
353
+ "for log_path in LOG_ROOT.rglob(\"transfer_*_finetune_on_*.log\"):\n",
354
+ " filename_match = filename_pattern.match(log_path.name)\n",
355
+ " trainOn = filename_match.group(\"trainOn\")\n",
356
+ " finetuneOn = filename_match.group(\"finetuneOn\")\n",
357
+ "\n",
358
+ " text = log_path.read_text(encoding=\"utf-8\", errors=\"ignore\")\n",
359
+ " matches = list(finetune_result_pattern.finditer(text))\n",
360
+ " if not matches:\n",
361
+ " continue\n",
362
+ " records.append({\n",
363
+ " \"trainOn\": trainOn,\n",
364
+ " \"finetuneOn\": finetuneOn,\n",
365
+ " \"srcc\": float(matches[-1].group(\"srcc\")),\n",
366
+ " \"plcc\": float(matches[-1].group(\"plcc\")),\n",
367
+ " })\n",
368
+ "finetune_df = (pd.DataFrame(records).sort_values([\"trainOn\", \"finetuneOn\"], ignore_index=True))\n",
369
+ "finetune_df"
370
+ ],
371
+ "id": "bfe4a5008377348d",
372
+ "outputs": [
373
+ {
374
+ "data": {
375
+ "text/plain": [
376
+ " trainOn finetuneOn srcc plcc\n",
377
+ "0 kvq shorts-hdr-dataset_hdr2sdr 0.6412 0.7225\n",
378
+ "1 kvq shorts-hdr-dataset_sdr 0.8291 0.8683\n",
379
+ "2 shorts-hdr-dataset_hdr2sdr kvq 0.8743 0.8792\n",
380
+ "3 shorts-hdr-dataset_hdr2sdr shorts-hdr-dataset_sdr 0.8183 0.8561\n",
381
+ "4 shorts-hdr-dataset_sdr kvq 0.8862 0.8883\n",
382
+ "5 shorts-hdr-dataset_sdr shorts-hdr-dataset_hdr2sdr 0.6589 0.8013"
383
+ ],
384
+ "text/html": [
385
+ "<div>\n",
386
+ "<style scoped>\n",
387
+ " .dataframe tbody tr th:only-of-type {\n",
388
+ " vertical-align: middle;\n",
389
+ " }\n",
390
+ "\n",
391
+ " .dataframe tbody tr th {\n",
392
+ " vertical-align: top;\n",
393
+ " }\n",
394
+ "\n",
395
+ " .dataframe thead th {\n",
396
+ " text-align: right;\n",
397
+ " }\n",
398
+ "</style>\n",
399
+ "<table border=\"1\" class=\"dataframe\">\n",
400
+ " <thead>\n",
401
+ " <tr style=\"text-align: right;\">\n",
402
+ " <th></th>\n",
403
+ " <th>trainOn</th>\n",
404
+ " <th>finetuneOn</th>\n",
405
+ " <th>srcc</th>\n",
406
+ " <th>plcc</th>\n",
407
+ " </tr>\n",
408
+ " </thead>\n",
409
+ " <tbody>\n",
410
+ " <tr>\n",
411
+ " <th>0</th>\n",
412
+ " <td>kvq</td>\n",
413
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
414
+ " <td>0.6412</td>\n",
415
+ " <td>0.7225</td>\n",
416
+ " </tr>\n",
417
+ " <tr>\n",
418
+ " <th>1</th>\n",
419
+ " <td>kvq</td>\n",
420
+ " <td>shorts-hdr-dataset_sdr</td>\n",
421
+ " <td>0.8291</td>\n",
422
+ " <td>0.8683</td>\n",
423
+ " </tr>\n",
424
+ " <tr>\n",
425
+ " <th>2</th>\n",
426
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
427
+ " <td>kvq</td>\n",
428
+ " <td>0.8743</td>\n",
429
+ " <td>0.8792</td>\n",
430
+ " </tr>\n",
431
+ " <tr>\n",
432
+ " <th>3</th>\n",
433
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
434
+ " <td>shorts-hdr-dataset_sdr</td>\n",
435
+ " <td>0.8183</td>\n",
436
+ " <td>0.8561</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <th>4</th>\n",
440
+ " <td>shorts-hdr-dataset_sdr</td>\n",
441
+ " <td>kvq</td>\n",
442
+ " <td>0.8862</td>\n",
443
+ " <td>0.8883</td>\n",
444
+ " </tr>\n",
445
+ " <tr>\n",
446
+ " <th>5</th>\n",
447
+ " <td>shorts-hdr-dataset_sdr</td>\n",
448
+ " <td>shorts-hdr-dataset_hdr2sdr</td>\n",
449
+ " <td>0.6589</td>\n",
450
+ " <td>0.8013</td>\n",
451
+ " </tr>\n",
452
+ " </tbody>\n",
453
+ "</table>\n",
454
+ "</div>"
455
+ ]
456
+ },
457
+ "execution_count": 24,
458
+ "metadata": {},
459
+ "output_type": "execute_result"
460
+ }
461
+ ],
462
+ "execution_count": 24
463
+ },
464
+ {
465
+ "metadata": {},
466
+ "cell_type": "markdown",
467
+ "source": "Complexity_test",
468
+ "id": "544f3bed13c8f3d9"
469
+ },
470
+ {
471
+ "metadata": {
472
+ "ExecuteTime": {
473
+ "end_time": "2026-05-19T12:42:02.339580Z",
474
+ "start_time": "2026-05-19T12:42:02.337969Z"
475
+ }
476
+ },
477
+ "cell_type": "code",
478
+ "source": "complexity_csv_path = \"../test_videos/complexity_test/\"",
479
+ "id": "9d718a191ee6bdf2",
480
+ "outputs": [],
481
+ "execution_count": 25
482
+ },
483
+ {
484
+ "metadata": {
485
+ "ExecuteTime": {
486
+ "end_time": "2026-05-19T12:42:02.347488Z",
487
+ "start_time": "2026-05-19T12:42:02.340191Z"
488
+ }
489
+ },
490
+ "cell_type": "code",
491
+ "source": [
492
+ "resolution_df1 = pd.read_csv(f\"{complexity_csv_path}complexity_resolution_KVQ_0546.csv\").assign(video=\"KVQ_0546\")\n",
493
+ "resolution_df2 = pd.read_csv(f\"{complexity_csv_path}complexity_resolution_SDR_Animal_5ngj.csv\").assign(video=\"SDR_Animal_5ngj\")\n",
494
+ "resolution_df = pd.concat([resolution_df1, resolution_df2], ignore_index=True)\n",
495
+ "resolution_df"
496
+ ],
497
+ "id": "5c99ce5680f78c34",
498
+ "outputs": [
499
+ {
500
+ "data": {
501
+ "text/plain": [
502
+ " model 540P 720P 1080P 2160P MOS video\n",
503
+ "0 FAST-VQA 0.638 0.745 1.020 2.680 4.017 KVQ_0546\n",
504
+ "1 FasterVQA 0.428 0.480 0.573 1.077 4.356 KVQ_0546\n",
505
+ "2 DOVER 1.018 1.163 1.511 3.662 4.223 KVQ_0546\n",
506
+ "3 Our 0.364 0.485 0.753 2.437 3.650 KVQ_0546\n",
507
+ "4 FAST-VQA 0.599 0.673 0.909 2.217 3.319 SDR_Animal_5ngj\n",
508
+ "5 FasterVQA 0.489 0.547 0.696 1.343 3.556 SDR_Animal_5ngj\n",
509
+ "6 DOVER 0.920 1.022 1.293 2.783 3.814 SDR_Animal_5ngj\n",
510
+ "7 Our 0.313 0.405 0.697 2.137 3.878 SDR_Animal_5ngj"
511
+ ],
512
+ "text/html": [
513
+ "<div>\n",
514
+ "<style scoped>\n",
515
+ " .dataframe tbody tr th:only-of-type {\n",
516
+ " vertical-align: middle;\n",
517
+ " }\n",
518
+ "\n",
519
+ " .dataframe tbody tr th {\n",
520
+ " vertical-align: top;\n",
521
+ " }\n",
522
+ "\n",
523
+ " .dataframe thead th {\n",
524
+ " text-align: right;\n",
525
+ " }\n",
526
+ "</style>\n",
527
+ "<table border=\"1\" class=\"dataframe\">\n",
528
+ " <thead>\n",
529
+ " <tr style=\"text-align: right;\">\n",
530
+ " <th></th>\n",
531
+ " <th>model</th>\n",
532
+ " <th>540P</th>\n",
533
+ " <th>720P</th>\n",
534
+ " <th>1080P</th>\n",
535
+ " <th>2160P</th>\n",
536
+ " <th>MOS</th>\n",
537
+ " <th>video</th>\n",
538
+ " </tr>\n",
539
+ " </thead>\n",
540
+ " <tbody>\n",
541
+ " <tr>\n",
542
+ " <th>0</th>\n",
543
+ " <td>FAST-VQA</td>\n",
544
+ " <td>0.638</td>\n",
545
+ " <td>0.745</td>\n",
546
+ " <td>1.020</td>\n",
547
+ " <td>2.680</td>\n",
548
+ " <td>4.017</td>\n",
549
+ " <td>KVQ_0546</td>\n",
550
+ " </tr>\n",
551
+ " <tr>\n",
552
+ " <th>1</th>\n",
553
+ " <td>FasterVQA</td>\n",
554
+ " <td>0.428</td>\n",
555
+ " <td>0.480</td>\n",
556
+ " <td>0.573</td>\n",
557
+ " <td>1.077</td>\n",
558
+ " <td>4.356</td>\n",
559
+ " <td>KVQ_0546</td>\n",
560
+ " </tr>\n",
561
+ " <tr>\n",
562
+ " <th>2</th>\n",
563
+ " <td>DOVER</td>\n",
564
+ " <td>1.018</td>\n",
565
+ " <td>1.163</td>\n",
566
+ " <td>1.511</td>\n",
567
+ " <td>3.662</td>\n",
568
+ " <td>4.223</td>\n",
569
+ " <td>KVQ_0546</td>\n",
570
+ " </tr>\n",
571
+ " <tr>\n",
572
+ " <th>3</th>\n",
573
+ " <td>Our</td>\n",
574
+ " <td>0.364</td>\n",
575
+ " <td>0.485</td>\n",
576
+ " <td>0.753</td>\n",
577
+ " <td>2.437</td>\n",
578
+ " <td>3.650</td>\n",
579
+ " <td>KVQ_0546</td>\n",
580
+ " </tr>\n",
581
+ " <tr>\n",
582
+ " <th>4</th>\n",
583
+ " <td>FAST-VQA</td>\n",
584
+ " <td>0.599</td>\n",
585
+ " <td>0.673</td>\n",
586
+ " <td>0.909</td>\n",
587
+ " <td>2.217</td>\n",
588
+ " <td>3.319</td>\n",
589
+ " <td>SDR_Animal_5ngj</td>\n",
590
+ " </tr>\n",
591
+ " <tr>\n",
592
+ " <th>5</th>\n",
593
+ " <td>FasterVQA</td>\n",
594
+ " <td>0.489</td>\n",
595
+ " <td>0.547</td>\n",
596
+ " <td>0.696</td>\n",
597
+ " <td>1.343</td>\n",
598
+ " <td>3.556</td>\n",
599
+ " <td>SDR_Animal_5ngj</td>\n",
600
+ " </tr>\n",
601
+ " <tr>\n",
602
+ " <th>6</th>\n",
603
+ " <td>DOVER</td>\n",
604
+ " <td>0.920</td>\n",
605
+ " <td>1.022</td>\n",
606
+ " <td>1.293</td>\n",
607
+ " <td>2.783</td>\n",
608
+ " <td>3.814</td>\n",
609
+ " <td>SDR_Animal_5ngj</td>\n",
610
+ " </tr>\n",
611
+ " <tr>\n",
612
+ " <th>7</th>\n",
613
+ " <td>Our</td>\n",
614
+ " <td>0.313</td>\n",
615
+ " <td>0.405</td>\n",
616
+ " <td>0.697</td>\n",
617
+ " <td>2.137</td>\n",
618
+ " <td>3.878</td>\n",
619
+ " <td>SDR_Animal_5ngj</td>\n",
620
+ " </tr>\n",
621
+ " </tbody>\n",
622
+ "</table>\n",
623
+ "</div>"
624
+ ]
625
+ },
626
+ "execution_count": 26,
627
+ "metadata": {},
628
+ "output_type": "execute_result"
629
+ }
630
+ ],
631
+ "execution_count": 26
632
+ },
633
+ {
634
+ "metadata": {
635
+ "ExecuteTime": {
636
+ "end_time": "2026-05-19T12:42:02.353757Z",
637
+ "start_time": "2026-05-19T12:42:02.348243Z"
638
+ }
639
+ },
640
+ "cell_type": "code",
641
+ "source": [
642
+ "complexity_df = pd.read_csv(f\"{complexity_csv_path}complexity_test.csv\")\n",
643
+ "complexity_df"
644
+ ],
645
+ "id": "f639ed0f66a50c57",
646
+ "outputs": [
647
+ {
648
+ "data": {
649
+ "text/plain": [
650
+ " model video Time Predicted_MOS Scaled_MOS MOS\n",
651
+ "0 FAST-VQA SDR_Gameplay_wcq2 0.822 0.248790 1.995 4.156\n",
652
+ "1 FAST-VQA SDR_Gameplay_s0pc 0.820 0.550190 3.201 4.264\n",
653
+ "2 FAST-VQA SDR_Gameplay_z5fz 0.823 0.538800 3.155 4.221\n",
654
+ "3 FAST-VQA SDR_Animal_5ngj 0.846 0.579850 3.319 4.308\n",
655
+ "4 FAST-VQA 0546 0.895 0.754290 4.017 2.389\n",
656
+ "5 FasterVQA SDR_Gameplay_wcq2 0.502 0.299890 2.200 4.156\n",
657
+ "6 FasterVQA SDR_Gameplay_s0pc 0.524 0.706330 3.825 4.264\n",
658
+ "7 FasterVQA SDR_Gameplay_z5fz 0.499 0.400900 2.604 4.221\n",
659
+ "8 FasterVQA SDR_Animal_5ngj 0.500 0.638990 3.556 4.308\n",
660
+ "9 FasterVQA 0546 0.484 0.839070 4.356 2.389\n",
661
+ "10 DOVER SDR_Gameplay_wcq2 1.407 0.347305 2.389 4.156\n",
662
+ "11 DOVER SDR_Gameplay_s0pc 1.345 0.582339 3.329 4.264\n",
663
+ "12 DOVER SDR_Gameplay_z5fz 1.341 0.524504 3.098 4.221\n",
664
+ "13 DOVER SDR_Animal_5ngj 1.326 0.703516 3.814 4.308\n",
665
+ "14 DOVER 0546 1.489 0.805786 4.223 2.389\n",
666
+ "15 Our SDR_Gameplay_wcq2 0.594 59.864200 3.395 4.156\n",
667
+ "16 Our SDR_Gameplay_s0pc 0.605 61.772900 3.471 4.264\n",
668
+ "17 Our SDR_Gameplay_z5fz 0.602 61.077000 3.443 4.221\n",
669
+ "18 Our SDR_Animal_5ngj 0.598 71.957800 3.878 4.308\n",
670
+ "19 Our 0546 0.734 66.257200 3.650 2.389"
671
+ ],
672
+ "text/html": [
673
+ "<div>\n",
674
+ "<style scoped>\n",
675
+ " .dataframe tbody tr th:only-of-type {\n",
676
+ " vertical-align: middle;\n",
677
+ " }\n",
678
+ "\n",
679
+ " .dataframe tbody tr th {\n",
680
+ " vertical-align: top;\n",
681
+ " }\n",
682
+ "\n",
683
+ " .dataframe thead th {\n",
684
+ " text-align: right;\n",
685
+ " }\n",
686
+ "</style>\n",
687
+ "<table border=\"1\" class=\"dataframe\">\n",
688
+ " <thead>\n",
689
+ " <tr style=\"text-align: right;\">\n",
690
+ " <th></th>\n",
691
+ " <th>model</th>\n",
692
+ " <th>video</th>\n",
693
+ " <th>Time</th>\n",
694
+ " <th>Predicted_MOS</th>\n",
695
+ " <th>Scaled_MOS</th>\n",
696
+ " <th>MOS</th>\n",
697
+ " </tr>\n",
698
+ " </thead>\n",
699
+ " <tbody>\n",
700
+ " <tr>\n",
701
+ " <th>0</th>\n",
702
+ " <td>FAST-VQA</td>\n",
703
+ " <td>SDR_Gameplay_wcq2</td>\n",
704
+ " <td>0.822</td>\n",
705
+ " <td>0.248790</td>\n",
706
+ " <td>1.995</td>\n",
707
+ " <td>4.156</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <th>1</th>\n",
711
+ " <td>FAST-VQA</td>\n",
712
+ " <td>SDR_Gameplay_s0pc</td>\n",
713
+ " <td>0.820</td>\n",
714
+ " <td>0.550190</td>\n",
715
+ " <td>3.201</td>\n",
716
+ " <td>4.264</td>\n",
717
+ " </tr>\n",
718
+ " <tr>\n",
719
+ " <th>2</th>\n",
720
+ " <td>FAST-VQA</td>\n",
721
+ " <td>SDR_Gameplay_z5fz</td>\n",
722
+ " <td>0.823</td>\n",
723
+ " <td>0.538800</td>\n",
724
+ " <td>3.155</td>\n",
725
+ " <td>4.221</td>\n",
726
+ " </tr>\n",
727
+ " <tr>\n",
728
+ " <th>3</th>\n",
729
+ " <td>FAST-VQA</td>\n",
730
+ " <td>SDR_Animal_5ngj</td>\n",
731
+ " <td>0.846</td>\n",
732
+ " <td>0.579850</td>\n",
733
+ " <td>3.319</td>\n",
734
+ " <td>4.308</td>\n",
735
+ " </tr>\n",
736
+ " <tr>\n",
737
+ " <th>4</th>\n",
738
+ " <td>FAST-VQA</td>\n",
739
+ " <td>0546</td>\n",
740
+ " <td>0.895</td>\n",
741
+ " <td>0.754290</td>\n",
742
+ " <td>4.017</td>\n",
743
+ " <td>2.389</td>\n",
744
+ " </tr>\n",
745
+ " <tr>\n",
746
+ " <th>5</th>\n",
747
+ " <td>FasterVQA</td>\n",
748
+ " <td>SDR_Gameplay_wcq2</td>\n",
749
+ " <td>0.502</td>\n",
750
+ " <td>0.299890</td>\n",
751
+ " <td>2.200</td>\n",
752
+ " <td>4.156</td>\n",
753
+ " </tr>\n",
754
+ " <tr>\n",
755
+ " <th>6</th>\n",
756
+ " <td>FasterVQA</td>\n",
757
+ " <td>SDR_Gameplay_s0pc</td>\n",
758
+ " <td>0.524</td>\n",
759
+ " <td>0.706330</td>\n",
760
+ " <td>3.825</td>\n",
761
+ " <td>4.264</td>\n",
762
+ " </tr>\n",
763
+ " <tr>\n",
764
+ " <th>7</th>\n",
765
+ " <td>FasterVQA</td>\n",
766
+ " <td>SDR_Gameplay_z5fz</td>\n",
767
+ " <td>0.499</td>\n",
768
+ " <td>0.400900</td>\n",
769
+ " <td>2.604</td>\n",
770
+ " <td>4.221</td>\n",
771
+ " </tr>\n",
772
+ " <tr>\n",
773
+ " <th>8</th>\n",
774
+ " <td>FasterVQA</td>\n",
775
+ " <td>SDR_Animal_5ngj</td>\n",
776
+ " <td>0.500</td>\n",
777
+ " <td>0.638990</td>\n",
778
+ " <td>3.556</td>\n",
779
+ " <td>4.308</td>\n",
780
+ " </tr>\n",
781
+ " <tr>\n",
782
+ " <th>9</th>\n",
783
+ " <td>FasterVQA</td>\n",
784
+ " <td>0546</td>\n",
785
+ " <td>0.484</td>\n",
786
+ " <td>0.839070</td>\n",
787
+ " <td>4.356</td>\n",
788
+ " <td>2.389</td>\n",
789
+ " </tr>\n",
790
+ " <tr>\n",
791
+ " <th>10</th>\n",
792
+ " <td>DOVER</td>\n",
793
+ " <td>SDR_Gameplay_wcq2</td>\n",
794
+ " <td>1.407</td>\n",
795
+ " <td>0.347305</td>\n",
796
+ " <td>2.389</td>\n",
797
+ " <td>4.156</td>\n",
798
+ " </tr>\n",
799
+ " <tr>\n",
800
+ " <th>11</th>\n",
801
+ " <td>DOVER</td>\n",
802
+ " <td>SDR_Gameplay_s0pc</td>\n",
803
+ " <td>1.345</td>\n",
804
+ " <td>0.582339</td>\n",
805
+ " <td>3.329</td>\n",
806
+ " <td>4.264</td>\n",
807
+ " </tr>\n",
808
+ " <tr>\n",
809
+ " <th>12</th>\n",
810
+ " <td>DOVER</td>\n",
811
+ " <td>SDR_Gameplay_z5fz</td>\n",
812
+ " <td>1.341</td>\n",
813
+ " <td>0.524504</td>\n",
814
+ " <td>3.098</td>\n",
815
+ " <td>4.221</td>\n",
816
+ " </tr>\n",
817
+ " <tr>\n",
818
+ " <th>13</th>\n",
819
+ " <td>DOVER</td>\n",
820
+ " <td>SDR_Animal_5ngj</td>\n",
821
+ " <td>1.326</td>\n",
822
+ " <td>0.703516</td>\n",
823
+ " <td>3.814</td>\n",
824
+ " <td>4.308</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <th>14</th>\n",
828
+ " <td>DOVER</td>\n",
829
+ " <td>0546</td>\n",
830
+ " <td>1.489</td>\n",
831
+ " <td>0.805786</td>\n",
832
+ " <td>4.223</td>\n",
833
+ " <td>2.389</td>\n",
834
+ " </tr>\n",
835
+ " <tr>\n",
836
+ " <th>15</th>\n",
837
+ " <td>Our</td>\n",
838
+ " <td>SDR_Gameplay_wcq2</td>\n",
839
+ " <td>0.594</td>\n",
840
+ " <td>59.864200</td>\n",
841
+ " <td>3.395</td>\n",
842
+ " <td>4.156</td>\n",
843
+ " </tr>\n",
844
+ " <tr>\n",
845
+ " <th>16</th>\n",
846
+ " <td>Our</td>\n",
847
+ " <td>SDR_Gameplay_s0pc</td>\n",
848
+ " <td>0.605</td>\n",
849
+ " <td>61.772900</td>\n",
850
+ " <td>3.471</td>\n",
851
+ " <td>4.264</td>\n",
852
+ " </tr>\n",
853
+ " <tr>\n",
854
+ " <th>17</th>\n",
855
+ " <td>Our</td>\n",
856
+ " <td>SDR_Gameplay_z5fz</td>\n",
857
+ " <td>0.602</td>\n",
858
+ " <td>61.077000</td>\n",
859
+ " <td>3.443</td>\n",
860
+ " <td>4.221</td>\n",
861
+ " </tr>\n",
862
+ " <tr>\n",
863
+ " <th>18</th>\n",
864
+ " <td>Our</td>\n",
865
+ " <td>SDR_Animal_5ngj</td>\n",
866
+ " <td>0.598</td>\n",
867
+ " <td>71.957800</td>\n",
868
+ " <td>3.878</td>\n",
869
+ " <td>4.308</td>\n",
870
+ " </tr>\n",
871
+ " <tr>\n",
872
+ " <th>19</th>\n",
873
+ " <td>Our</td>\n",
874
+ " <td>0546</td>\n",
875
+ " <td>0.734</td>\n",
876
+ " <td>66.257200</td>\n",
877
+ " <td>3.650</td>\n",
878
+ " <td>2.389</td>\n",
879
+ " </tr>\n",
880
+ " </tbody>\n",
881
+ "</table>\n",
882
+ "</div>"
883
+ ]
884
+ },
885
+ "execution_count": 27,
886
+ "metadata": {},
887
+ "output_type": "execute_result"
888
+ }
889
+ ],
890
+ "execution_count": 27
891
+ }
892
+ ],
893
+ "metadata": {
894
+ "kernelspec": {
895
+ "display_name": "Python 3",
896
+ "language": "python",
897
+ "name": "python3"
898
+ },
899
+ "language_info": {
900
+ "codemirror_mode": {
901
+ "name": "ipython",
902
+ "version": 2
903
+ },
904
+ "file_extension": ".py",
905
+ "mimetype": "text/x-python",
906
+ "name": "python",
907
+ "nbconvert_exporter": "python",
908
+ "pygments_lexer": "ipython2",
909
+ "version": "2.7.6"
910
+ }
911
+ },
912
+ "nbformat": 4,
913
+ "nbformat_minor": 5
914
+ }
src/demo_test.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ.setdefault("DECORD_DUPLICATE_WARNING_THRESHOLD", "1.0")
3
+
4
+ import argparse
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.amp import autocast
12
+ from tqdm import tqdm
13
+ from thop import profile
14
+ from thop import clever_format
15
+
16
+ from train import VQADataset
17
+ from model.qd_model import QD_MODEL
18
+
19
+
20
+ def load_checkpoint(ckpt_path, device):
21
+ ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=True)
22
+ if isinstance(ckpt, dict) and "model" in ckpt:
23
+ return {
24
+ "state_dict": ckpt["model"],
25
+ "train_mos_mean": ckpt.get("mos_mean"),
26
+ "train_mos_std": ckpt.get("mos_std"),
27
+ "train_args": ckpt.get("args", {}),
28
+ "is_full_checkpoint": True,
29
+ }
30
+ if isinstance(ckpt, dict):
31
+ return {
32
+ "state_dict": ckpt,
33
+ "train_mos_mean": None,
34
+ "train_mos_std": None,
35
+ "train_args": {},
36
+ "is_full_checkpoint": False,
37
+ }
38
+ raise TypeError(f"Unsupported checkpoint type: {type(ckpt)!r}")
39
+
40
+ class ForwardWrapper(nn.Module):
41
+ def __init__(self, model):
42
+ super().__init__()
43
+ self.model = model
44
+
45
+ def forward(self, rgb, w_art, w_str):
46
+ yhat, _aux = self.model(rgb, w_art, w_str)
47
+ return yhat
48
+
49
+ def count_parameters(model):
50
+ total = sum(p.numel() for p in model.parameters())
51
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
52
+ return total, trainable
53
+
54
+ def build_dataset(args, mos_mean, mos_std):
55
+ rows = [(args.video_id, float(args.dummy_mos))]
56
+ dataset = VQADataset(
57
+ rows,
58
+ args.db_path,
59
+ clip_len=args.clip_len,
60
+ size=args.resize,
61
+ win=args.win,
62
+ win_step=args.win_step,
63
+ mos_mean=float(mos_mean),
64
+ mos_std=float(mos_std),
65
+ )
66
+ return dataset
67
+
68
+ def prepare_single_sample(dataset, device):
69
+ rgb, w_art, w_str, y, vid = dataset[0]
70
+
71
+ rgb = rgb.unsqueeze(0).to(device, non_blocking=True)
72
+ w_art = w_art.unsqueeze(0).to(device, non_blocking=True)
73
+ w_str = w_str.unsqueeze(0).to(device, non_blocking=True)
74
+ y = y.unsqueeze(0).to(device, non_blocking=True).float()
75
+
76
+ if isinstance(vid, (list, tuple)):
77
+ video_id = vid[0]
78
+ else:
79
+ video_id = vid
80
+ return rgb, w_art, w_str, y, video_id
81
+
82
+
83
+ @torch.no_grad()
84
+ def predict_once(
85
+ model,
86
+ rgb,
87
+ w_art,
88
+ w_str,
89
+ *,
90
+ device,
91
+ amp,
92
+ train_mos_mean,
93
+ train_mos_std,
94
+ ):
95
+ model.eval()
96
+ device_type = "cuda" if str(device).startswith("cuda") else "cpu"
97
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
98
+ yhat, _aux = model(rgb, w_art, w_str)
99
+
100
+ pred_score = yhat.detach().float().cpu() * float(train_mos_std) + float(train_mos_mean)
101
+ return float(pred_score.squeeze().item())
102
+
103
+ @torch.no_grad()
104
+ def profile_with_thop(model, rgb, w_art, w_str):
105
+ macs, params = profile(model, inputs=(rgb, w_art, w_str), verbose=False)
106
+ flops = 2 * macs
107
+ macs, flops, params = clever_format([macs, flops, params], "%.3f")
108
+ return macs, flops, params
109
+
110
+ @torch.no_grad()
111
+ def benchmark_forward(model, rgb, w_art, w_str, *, device, amp, num_runs=10, warmup=3):
112
+ model.eval()
113
+ device_type = "cuda" if str(device).startswith("cuda") else "cpu"
114
+
115
+ for _ in range(max(0, warmup)):
116
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
117
+ _ = model(rgb, w_art, w_str)
118
+ if device_type == "cuda":
119
+ torch.cuda.synchronize()
120
+
121
+ start = time.perf_counter()
122
+ for _ in range(int(num_runs)):
123
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
124
+ _ = model(rgb, w_art, w_str)
125
+ if device_type == "cuda":
126
+ torch.cuda.synchronize()
127
+ return (time.perf_counter() - start) / max(1, int(num_runs))
128
+
129
+
130
+ def run_end_to_end_once(args, model, train_mos_mean, train_mos_std, device, amp):
131
+ start = time.perf_counter()
132
+ dataset = build_dataset(args, train_mos_mean, train_mos_std)
133
+ rgb, w_art, w_str, _y, _video_id = prepare_single_sample(dataset, device)
134
+
135
+ pred_score = predict_once(
136
+ model,
137
+ rgb,
138
+ w_art,
139
+ w_str,
140
+ device=device,
141
+ amp=amp,
142
+ train_mos_mean=float(train_mos_mean),
143
+ train_mos_std=float(train_mos_std),
144
+ )
145
+ if str(device).startswith("cuda"):
146
+ torch.cuda.synchronize()
147
+ elapsed = time.perf_counter() - start
148
+ return elapsed, pred_score, rgb, w_art, w_str
149
+
150
+
151
+ def parse_args():
152
+ ap = argparse.ArgumentParser(description="Demo-style single-video test for QD_MODEL")
153
+ # for complexity time test:
154
+ ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/lsvq/qd_model.best.pt")
155
+ ap.add_argument("--db_path", type=str, default="/home/xinyi/Project/FD-VQA/test_videos/")
156
+ ap.add_argument("--video_id", type=str, default="SDR_Animal_5ngj")
157
+ # for resolution compelxity test:
158
+ # ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/kvq/qd_model.best.pt")
159
+ # ap.add_argument("--db_path", type=str, default="/home/xinyi/Project/FD-VQA/test_videos/complexity_test/complexity_resolution/")
160
+ # ap.add_argument("--video_id", type=str, default="SDR_Animal_5ngj_540p")
161
+
162
+ ap.add_argument("--clip_len", type=int, default=16)
163
+ ap.add_argument("--resize", type=int, default=224)
164
+ ap.add_argument("--win", type=int, default=6)
165
+ ap.add_argument("--win_step", type=int, default=1)
166
+
167
+ ap.add_argument("--device", type=str, default="cuda")
168
+ ap.add_argument("--no_amp", action="store_true")
169
+
170
+ ap.add_argument("--dummy_mos", type=float, default=3.0, help="Only used to compute VQADataset, does not affect prediction")
171
+ ap.add_argument("--num_runs", type=int, default=10, help="Average N runs")
172
+ ap.add_argument("--warmup_runs", type=int, default=3)
173
+ ap.add_argument("--skip_profile", action="store_true")
174
+ return ap.parse_args()
175
+
176
+
177
+ def main():
178
+ args = parse_args()
179
+ device = torch.device(args.device)
180
+ amp = not bool(args.no_amp)
181
+
182
+ print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
183
+
184
+ display_path = str(Path(args.db_path) / args.video_id)
185
+ info = pd.DataFrame([
186
+ {
187
+ "vid": args.video_id,
188
+ "test_video_path": display_path,
189
+ }
190
+ ])
191
+ print(info)
192
+
193
+ dataset_preview = build_dataset(args, mos_mean=args.dummy_mos, mos_std=1.0)
194
+ print(f"Dataset loaded. Total videos: {len(dataset_preview)}, Total batches: 1")
195
+ print(f"Loading model from: {args.ckpt_path}")
196
+
197
+ ckpt_info = load_checkpoint(Path(args.ckpt_path), device)
198
+ train_mos_mean = ckpt_info["train_mos_mean"]
199
+ train_mos_std = ckpt_info["train_mos_std"]
200
+ if train_mos_mean is None or train_mos_std is None:
201
+ raise ValueError("Checkpoint does not contain mos_mean / mos_std. Please use a full checkpoint.")
202
+ if float(train_mos_std) <= 1e-8:
203
+ raise ValueError("train_mos_std must be > 0")
204
+
205
+ model = QD_MODEL(clip_model="openai/clip-vit-base-patch16").to(device)
206
+ model.load_state_dict(ckpt_info["state_dict"], strict=True)
207
+ model.eval()
208
+
209
+ run_times = []
210
+ pred_score = None
211
+ rgb = w_art = w_str = None
212
+
213
+ for i in range(args.num_runs):
214
+ for _ in tqdm(range(1), desc="Processing Videos"):
215
+ elapsed, pred_score, rgb, w_art, w_str = run_end_to_end_once(
216
+ args, model, train_mos_mean, train_mos_std, device, amp
217
+ )
218
+ run_times.append(elapsed)
219
+ print(f"Run {i + 1} - Time taken: {elapsed:.4f} seconds")
220
+
221
+ avg_total_time = sum(run_times) / max(1, len(run_times))
222
+ avg_forward_time = benchmark_forward(
223
+ model,
224
+ rgb,
225
+ w_art,
226
+ w_str,
227
+ device=device,
228
+ amp=amp,
229
+ num_runs=args.num_runs,
230
+ warmup=args.warmup_runs,
231
+ )
232
+
233
+ total_params, trainable_params = count_parameters(model)
234
+ macs = flops = params = None
235
+ if not args.skip_profile:
236
+ try:
237
+ macs, flops, params = profile_with_thop(model, rgb, w_art, w_str)
238
+ except Exception as e:
239
+ print(f"[WARN] THOP profiling failed: {e}")
240
+
241
+ print(f"Average running time over {args.num_runs} runs: {avg_total_time:.4f} seconds")
242
+ print(f"Predicted Quality Score: {pred_score:.4f}")
243
+ print("\n========== PROFILE SUMMARY ==========")
244
+ print(f"video_id : {args.video_id}")
245
+ print(f"rgb shape : {tuple(rgb.shape)}")
246
+ print(f"w_art shape : {tuple(w_art.shape)}")
247
+ print(f"w_str shape : {tuple(w_str.shape)}")
248
+ print(f"train mos mean/std : {float(train_mos_mean):.6f} / {float(train_mos_std):.6f}")
249
+ print(f"predicted score : {pred_score:.6f}")
250
+ print(f"params total : {total_params:,} ({total_params / 1e6:.3f} M)")
251
+ print(f"params trainable : {trainable_params:,} ({trainable_params / 1e6:.3f} M)")
252
+ if params is not None:
253
+ print(f"Params (THOP) : {params} M")
254
+ if macs is not None:
255
+ print(f"MACs (THOP) : {macs} G")
256
+ if flops is not None:
257
+ print(f"FLOPs (~2*MACs) : {flops} G")
258
+ print(f"avg forward time : {avg_forward_time:.6f} s (runs={args.num_runs})")
259
+ print(f"avg end-to-end time : {avg_total_time:.6f} s (sample prep + forward)")
260
+ print("=====================================")
261
+
262
+
263
+ if __name__ == "__main__":
264
+ main()
src/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/model/__init__.py
src/model/clip_dense_encoder.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPModel
4
+ from transformers.utils import logging
5
+ logging.set_verbosity_error()
6
+
7
+
8
+ class CLIPDenseEncoder(nn.Module):
9
+ def __init__(self, model_name="openai/clip-vit-base-patch16"):
10
+ super().__init__()
11
+ self.clip = CLIPModel.from_pretrained(str(model_name), use_safetensors=True)
12
+
13
+ self.vision = self.clip.vision_model
14
+ vcfg = self.clip.config.vision_config
15
+ self.hidden_size = int(vcfg.hidden_size)
16
+ self.patch_size = int(vcfg.patch_size)
17
+ self.image_size = int(vcfg.image_size)
18
+
19
+ # OpenAI CLIP normalization constants
20
+ self.register_buffer("_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
21
+ self.register_buffer("_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
22
+
23
+ def forward(self, x01):
24
+ x = (x01 - self._mean.to(dtype=x01.dtype)) / self._std.to(dtype=x01.dtype)
25
+ out = self.vision(pixel_values=x)
26
+ tokens = out.last_hidden_state[:, 1:, :] # [B, 1+N, C] drop CLS -> [B, N, C]
27
+ b, n, c = tokens.shape
28
+ side = int(n**0.5)
29
+ if side * side != n:
30
+ raise RuntimeError(f"CLIP patch tokens N={n} not a square; cannot reshape to 2D grid.")
31
+
32
+ fmap = tokens.transpose(1, 2).contiguous().view(b, c, side, side)
33
+ return fmap
34
+
35
+ def freeze_all(self):
36
+ for p in self.clip.parameters():
37
+ p.requires_grad = False
38
+
39
+ def unfreeze_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
40
+ self.freeze_all()
41
+ layers = self.vision.encoder.layers
42
+ n_total = len(layers)
43
+
44
+ # all unfreeze
45
+ if n_blocks < 0 or n_blocks >= n_total:
46
+ for p in self.vision.parameters():
47
+ p.requires_grad = True
48
+ # unfreeze n blocks
49
+ else:
50
+ k = max(0, min(int(n_blocks), n_total))
51
+ for i in range(n_total - k, n_total):
52
+ for p in layers[i].parameters():
53
+ p.requires_grad = True
54
+
55
+ if also_unfreeze_ln:
56
+ if hasattr(self.vision, "pre_layrnorm"):
57
+ for p in self.vision.pre_layrnorm.parameters():
58
+ p.requires_grad = True
59
+ if hasattr(self.vision, "post_layernorm"):
60
+ for p in self.vision.post_layernorm.parameters():
61
+ p.requires_grad = True
src/model/qd_model.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----------------------------
2
+ # Model: CLIP encoder + w_art/w_str pooling -> weighted sequences -> seq model -> head fusion
3
+ # ----------------------------
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .clip_dense_encoder import CLIPDenseEncoder
8
+
9
+
10
+ def weighted_pool_2d(fmap, wmap, eps=1e-6):
11
+ """
12
+ fmap: [B, C, H, W]
13
+ wmap: [B, 1, H, W]
14
+ -> [B, C]
15
+ """
16
+ w = wmap.clamp(0.0, 1.0)
17
+ w = w / (w.sum(dim=(2, 3), keepdim=True) + eps)
18
+ return (fmap * w).sum(dim=(2, 3))
19
+
20
+ def default_gate_stats(w_art, w_str, fmap_bt, eps=1e-6):
21
+ """
22
+ w_art: [B, 1, T, H, W]
23
+ w_str: [B, 1, T, H, W]
24
+ fmap_bt: [B, T, C, H, W]
25
+ -> [B, 3]
26
+ """
27
+ mu_art = w_art.mean(dim=(1, 2, 3, 4))
28
+ mu_str = w_str.mean(dim=(1, 2, 3, 4))
29
+ mu_raw = fmap_bt.abs().mean(dim=(1, 2, 3, 4))
30
+
31
+ stats = torch.stack([mu_art, mu_str, mu_raw], dim=1)
32
+ return stats
33
+
34
+ def two_gate_stats(w_art, w_str):
35
+ # [B, 2]
36
+ mu_art = w_art.mean(dim=(1, 2, 3, 4))
37
+ mu_str = w_str.mean(dim=(1, 2, 3, 4))
38
+ return torch.stack([mu_art, mu_str], dim=1)
39
+
40
+ class QD_MODEL(nn.Module):
41
+ def __init__(
42
+ self,
43
+ *,
44
+ clip_model="openai/clip-vit-base-patch16",
45
+ head_hidden=384,
46
+ gate_hidden=32,
47
+ head_dropout=0.2,
48
+ gate_dropout=0.1,
49
+ ablation_mode="full", # "full" | "art" | "str" | "raw"
50
+ ):
51
+ super().__init__()
52
+ self.ablation_mode = ablation_mode
53
+
54
+ self.encoder = CLIPDenseEncoder(model_name=str(clip_model))
55
+ c = int(self.encoder.hidden_size)
56
+
57
+ # 3 heads
58
+ self.head_art = nn.Sequential(
59
+ nn.Linear(c, head_hidden),
60
+ nn.ReLU(),
61
+ nn.Dropout(head_dropout),
62
+ nn.Linear(head_hidden, head_hidden // 2),
63
+ nn.ReLU(),
64
+ nn.Dropout(head_dropout),
65
+ nn.Linear(head_hidden // 2, 1),
66
+ )
67
+ self.head_str = nn.Sequential(
68
+ nn.Linear(c, head_hidden),
69
+ nn.ReLU(),
70
+ nn.Dropout(head_dropout),
71
+ nn.Linear(head_hidden, head_hidden // 2),
72
+ nn.ReLU(),
73
+ nn.Dropout(head_dropout),
74
+ nn.Linear(head_hidden // 2, 1),
75
+ )
76
+ self.head_fmap = nn.Sequential(
77
+ nn.Linear(c, head_hidden),
78
+ nn.ReLU(),
79
+ nn.Dropout(head_dropout),
80
+ nn.Linear(head_hidden, head_hidden // 2),
81
+ nn.ReLU(),
82
+ nn.Dropout(head_dropout),
83
+ nn.Linear(head_hidden // 2, 1),
84
+ )
85
+ # full model: 3-way gate
86
+ self.gate = nn.Sequential(
87
+ nn.Linear(3, gate_hidden),
88
+ nn.ReLU(),
89
+ nn.Dropout(gate_dropout),
90
+ nn.Linear(gate_hidden, 3),
91
+ )
92
+ # ablation: two gates only
93
+ self.gate_two = nn.Sequential(
94
+ nn.Linear(2, gate_hidden),
95
+ nn.ReLU(),
96
+ nn.Dropout(gate_dropout),
97
+ nn.Linear(gate_hidden, 2),
98
+ )
99
+
100
+ def freeze_clip_all(self):
101
+ self.encoder.freeze_all()
102
+
103
+ def unfreeze_clip_last_blocks(self, n_blocks=2, also_unfreeze_ln=True):
104
+ self.encoder.unfreeze_last_blocks(
105
+ n_blocks=n_blocks,
106
+ also_unfreeze_ln=also_unfreeze_ln,
107
+ )
108
+
109
+ def temporal_pool(self, x):
110
+ """
111
+ x: [B, T, C]
112
+ return: [B, C]
113
+ """
114
+ return x.mean(dim=1)
115
+
116
+ def forward(self, rgb, w_art, w_str, gate_stats=None):
117
+ """
118
+ rgb: [B,3,T,H,W] (rgb in [0,1])
119
+ w_art/w_str: [B,1,T,H,W] (0..1)
120
+ gate_stats: [B,3]
121
+ """
122
+ B, _C, T, H, W = rgb.shape
123
+
124
+ x2d = rgb.permute(0, 2, 1, 3, 4).contiguous().view(B * T, 3, H, W)
125
+ fmap2d = self.encoder(x2d)
126
+ _, C, Hp, Wp = fmap2d.shape # (B*T, 768, 14, 14)
127
+
128
+ w_art_bt = w_art.transpose(1, 2) # (B, T, 1, 14, 14)
129
+ w_str_bt = w_str.transpose(1, 2)
130
+ fmap_bt = fmap2d.view(B, T, C, Hp, Wp) # (B, T, 768, 14, 14)
131
+
132
+ z_art = torch.stack([weighted_pool_2d(fmap_bt[:, i], w_art_bt[:, i]) for i in range(T)], dim=1) # (B, T, 768)
133
+ z_str = torch.stack([weighted_pool_2d(fmap_bt[:, i], w_str_bt[:, i]) for i in range(T)], dim=1)
134
+ z_raw = fmap_bt.mean(dim=(-2, -1))
135
+
136
+ h_art = self.temporal_pool(z_art) # [B, C]
137
+ h_str = self.temporal_pool(z_str)
138
+ h_fmap = self.temporal_pool(z_raw)
139
+
140
+ q_art = self.head_art(h_art)
141
+ q_str = self.head_str(h_str)
142
+ q_fmap = self.head_fmap(h_fmap)
143
+
144
+ mode = self.ablation_mode.lower()
145
+ if mode == "art":
146
+ y_hat = q_art
147
+ weights = torch.tensor([1.0, 0.0, 0.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
148
+ elif mode == "str":
149
+ y_hat = q_str
150
+ weights = torch.tensor([0.0, 1.0, 0.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
151
+ elif mode == "raw":
152
+ y_hat = q_fmap
153
+ weights = torch.tensor([0.0, 0.0, 1.0], device=q_art.device).unsqueeze(0).repeat(B, 1)
154
+ elif mode == "art+str":
155
+ if gate_stats is None:
156
+ gate_stats = two_gate_stats(w_art, w_str) # [B, 2]
157
+ g_ar = self.gate_two(gate_stats) # [B, 2]
158
+ a, b_ = torch.softmax(g_ar, dim=1).split(1, dim=1)
159
+ y_hat = a * q_art + b_ * q_str
160
+ zero = torch.zeros_like(a)
161
+ weights = torch.cat([a, b_, zero], dim=1)
162
+ elif mode == "full":
163
+ if gate_stats is None:
164
+ gate_stats = default_gate_stats(w_art, w_str, fmap_bt)
165
+ g = self.gate(gate_stats)
166
+ a, b_, c_ = torch.softmax(g, dim=1).split(1, dim=1)
167
+ y_hat = a * q_art + b_ * q_str + c_ * q_fmap
168
+ weights = torch.cat([a, b_, c_], dim=1)
169
+ else:
170
+ raise ValueError(f"Unknown ablation_mode: {self.ablation_mode}")
171
+
172
+ aux = (
173
+ q_art.squeeze(1),
174
+ q_str.squeeze(1),
175
+ q_fmap.squeeze(1),
176
+ weights,
177
+ )
178
+ return y_hat.squeeze(1), aux
src/module/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/module/__init__.py
src/module/compute_weight_map.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import csv
3
+ from pathlib import Path
4
+ import os
5
+ os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
6
+ from decord import VideoReader, cpu
7
+
8
+ from module.frequency_dct import compute_twostream_dct
9
+ from module.read_frame_decord import (
10
+ sample_frames_uniform,
11
+ collect_needed,
12
+ cache_needed_frames,
13
+ read_window_from_cache,
14
+ )
15
+
16
+ def process_video(
17
+ vr,
18
+ # resize
19
+ size=224,
20
+ # anchors + window
21
+ num_anchors=16,
22
+ win=6,
23
+ win_step=1,
24
+ # DCT / weights
25
+ block=16,
26
+ ):
27
+ """
28
+ Returns (frame_all, w_art_all, w_str_all, anchors_kept).
29
+ frame_all: anchor frames (RGB uint8, HxWx3)
30
+ w_art_all/w_str_all: maps (float32, HxW)
31
+ anchors_kept: frame indices used per anchor - list[list[int]]
32
+ """
33
+ total_frames = len(vr)
34
+ if total_frames <= 1:
35
+ raise RuntimeError(f"Video too short / invalid frame count: {total_frames}")
36
+
37
+ anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
38
+ needed = collect_needed(anchor_idxs, total_frames, win, win_step)
39
+ cache = cache_needed_frames(vr, needed, size)
40
+
41
+ frame_all, w_art_all, w_str_all = [], [], []
42
+ anchors_kept = []
43
+
44
+ for anchor in anchor_idxs:
45
+ out = read_window_from_cache(cache, anchor, total_frames, win, win_step)
46
+ if out is None:
47
+ continue
48
+
49
+ anchor_frame, gray_seq, idxs = out
50
+
51
+ w_art, w_str, _dbg = compute_twostream_dct(
52
+ gray_seq,
53
+ block=block,
54
+ )
55
+
56
+ frame_all.append(anchor_frame)
57
+ w_art_all.append(w_art.astype(np.float32, copy=False))
58
+ w_str_all.append(w_str.astype(np.float32, copy=False))
59
+ anchors_kept.append([int(x) for x in idxs])
60
+
61
+ return frame_all, w_art_all, w_str_all, anchors_kept
62
+
63
+
64
+ def compute_weight_map(frame_all, w_art_all, w_str_all):
65
+ if len(frame_all) == 0:
66
+ raise ValueError("No frames produced.")
67
+ if not (len(frame_all) == len(w_art_all) == len(w_str_all)):
68
+ raise ValueError(
69
+ f"Length mismatch: frames={len(frame_all)}, w_art_all={len(w_art_all)}, w_str_all={len(w_str_all)}"
70
+ )
71
+ frames_np = np.stack(frame_all, axis=0) # (N,H,W,3) uint8
72
+ w_art_np = np.stack(w_art_all, axis=0) # (N,H,W) float32
73
+ w_str_np = np.stack(w_str_all, axis=0) # (N,H,W) float32
74
+ return frames_np, w_art_np, w_str_np
75
+
76
+ def read_vid_mos_csv(csv_path):
77
+ rows = []
78
+ with open(csv_path, "r", encoding="utf-8") as f:
79
+ reader = csv.DictReader(f)
80
+ if not reader.fieldnames:
81
+ raise RuntimeError("CSV has no header")
82
+ for r in reader:
83
+ vid = str(r["vid"]).strip()
84
+ mos = float(r["mos"])
85
+ rows.append((vid, mos))
86
+ return rows
87
+
88
+ def rng_rows(rows, seed=0):
89
+ rng = np.random.default_rng(int(seed))
90
+ idx = np.arange(len(rows))
91
+ rng.shuffle(idx)
92
+ train = [rows[i] for i in idx[:]]
93
+ return train
94
+
95
+ if __name__ == "__main__":
96
+ csv_path = "/home/xinyi/Project/FD-VQA/metadata/TEST_metadata.csv"
97
+ db_path = "/home/xinyi/Project/FD-VQA/test_videos/"
98
+ # video_path = "/home/xinyi/Project/FD-VQA/test_videos/NesAirFortressIn4108.37ByTool23.mp4"
99
+
100
+ rows = read_vid_mos_csv(str(csv_path))
101
+ train = rng_rows(rows)
102
+ for i in range(len(train)):
103
+ vid, mos = train[i]
104
+ print(vid, mos)
105
+ # get video path
106
+ base_path = Path(db_path) / vid
107
+ video_path = None
108
+ for ext in ("mp4", "avi", "mkv"):
109
+ p = Path(str(base_path) + f".{ext}")
110
+ if p.exists():
111
+ video_path = str(p)
112
+ break
113
+ if video_path is None:
114
+ raise FileNotFoundError(f"Cannot find {vid} video")
115
+
116
+ try:
117
+ # read video
118
+ vr = VideoReader(video_path, ctx=cpu(0))
119
+ frame_all, w_art_all, w_str_all, anchors_kept = process_video(
120
+ vr,
121
+ size=224,
122
+ num_anchors=16,
123
+ win=6,
124
+ win_step=1,
125
+ block=16,
126
+ )
127
+ frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
128
+ print("frames_np:", frames_np.shape, frames_np.dtype)
129
+ print("w_art_np:", w_art_np.shape, w_art_np.dtype)
130
+ print("w_str_np:", w_str_np.shape, w_str_np.dtype)
131
+ print("anchors_kept:", len(anchors_kept), "example:", anchors_kept)
132
+ except Exception as e:
133
+ print("\n[DATA ERROR]")
134
+ print("idx:", i)
135
+ print("vid:", vid)
136
+ print("path:", video_path)
137
+ raise
138
+ finally:
139
+ # release decord video reader
140
+ try:
141
+ if vr is not None:
142
+ del vr
143
+ except Exception:
144
+ pass
src/module/frequency_dct.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
3
+ from decord import VideoReader, cpu
4
+ import cv2
5
+ import argparse
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import matplotlib
9
+ matplotlib.use("Agg")
10
+ import matplotlib.pyplot as plt
11
+
12
+ from module.read_frame_decord import sample_frames_uniform, collect_needed, cache_needed_frames, read_window_from_cache
13
+
14
+ # ----------------------------
15
+ # utils
16
+ # ----------------------------
17
+ def norm01(x, eps=1e-6):
18
+ x = x.astype(np.float32)
19
+ mn, mx = float(x.min()), float(x.max())
20
+ return (x - mn) / (mx - mn + eps)
21
+
22
+ # def upsample_grid(grid, out_hw, interp=cv2.INTER_LINEAR):
23
+ def upsample_grid(grid, out_hw, interp=cv2.INTER_NEAREST):
24
+ H, W = out_hw
25
+ return cv2.resize(grid.astype(np.float32), (W, H), interpolation=interp)
26
+
27
+ def gaussian_blur(x, sigma=1.0):
28
+ if sigma <= 0:
29
+ return x
30
+ ksize = int(6 * sigma + 1)
31
+ if ksize % 2 == 0:
32
+ ksize += 1
33
+ if ksize < 3:
34
+ ksize = 3
35
+ return cv2.GaussianBlur(x, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT_101)
36
+
37
+ # sobel operator
38
+ def gradient_mag(gray):
39
+ gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
40
+ gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
41
+ dst = np.sqrt(gx * gx + gy * gy)
42
+ # plt.imshow(dst, cmap='gray')
43
+ # plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/Sobel_operator_result.jpg', dpi=300)
44
+ return dst
45
+
46
+ # Sobel magnitude -> normalize -> threshold -> dilate edge region
47
+ def edge_sobel(gray, thr=0.20, dilate_px=2):
48
+ g = gradient_mag(gray).astype(np.float32)
49
+ g = norm01(g) # normalize
50
+ edge = (g >= thr).astype(np.uint8)
51
+ if dilate_px > 0:
52
+ k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * dilate_px + 1, 2 * dilate_px + 1))
53
+ edge = cv2.dilate(edge, k, iterations=1)
54
+ # plt.imshow(edge, cmap='gray')
55
+ # plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/edge_result.jpg', dpi=300)
56
+ return edge
57
+
58
+ # per-block fraction of near edge pixels
59
+ def block_edge(edge_mask, block=16):
60
+ h, w = edge_mask.shape
61
+ gh, gw = h // block, w // block
62
+ frac = np.zeros((gh, gw), dtype=np.float32)
63
+ for by in range(gh):
64
+ for bx in range(gw):
65
+ y0, x0 = by * block, bx * block
66
+ frac[by, bx] = float(edge_mask[y0:y0 + block, x0:x0 + block].mean())
67
+ # plt.imshow(frac, cmap='gray')
68
+ # plt.savefig('/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only/edge_block.jpg', dpi=300)
69
+ return frac
70
+
71
+ # per-block mean
72
+ def block_mean(img, block=16):
73
+ h, w = img.shape
74
+ gh, gw = h // block, w // block
75
+ x = img[:gh*block, :gw*block]
76
+ x = x.reshape(gh, block, gw, block) # (block_row, inblock_row, block_col, inblock_col)
77
+ return x.mean(axis=(1, 3)).astype(np.float32)
78
+
79
+ # discontinuities across block boundaries: |I[:, x] - I[:, x-1]| & |I[y, :] - I[y-1, :]|
80
+ def blockiness_boundary_map(gray, block_size=16, blur_sigma=1.0):
81
+ h, w = gray.shape
82
+ v = np.zeros((h, w), dtype=np.float32)
83
+ hmap = np.zeros((h, w), dtype=np.float32)
84
+
85
+ for x in range(block_size, w, block_size):
86
+ v[:, x] = np.abs(gray[:, x] - gray[:, x - 1])
87
+ for y in range(block_size, h, block_size):
88
+ hmap[y, :] = np.abs(gray[y, :] - gray[y - 1, :])
89
+ b = v + hmap
90
+ if blur_sigma > 0:
91
+ b = gaussian_blur(b, blur_sigma)
92
+ return b
93
+
94
+ # ----------------------------
95
+ # DCT 16x16 energies
96
+ # ----------------------------
97
+ def dct_energy_ratios(gray, block=16, low_k=4, mid_k=8, eps=1e-6):
98
+ """
99
+ Return per-block energy ratios for frequency bands.
100
+ low: [0:low_k, 0:low_k]
101
+ mid: [0:mid_k, 0:mid_k] - low
102
+ high: remaining
103
+ """
104
+ H, W = gray.shape
105
+ gh, gw = H // block, W // block
106
+ r_low = np.zeros((gh, gw), dtype=np.float32)
107
+ r_mid = np.zeros((gh, gw), dtype=np.float32)
108
+ r_high = np.zeros((gh, gw), dtype=np.float32)
109
+
110
+ for by in range(gh):
111
+ for bx in range(gw):
112
+ y0, x0 = by * block, bx * block
113
+ patch = gray[y0:y0 + block, x0:x0 + block].astype(np.float32)
114
+
115
+ C = cv2.dct(patch)
116
+ E = C * C
117
+ E_low = E[:low_k, :low_k].sum() # 4 x 4
118
+ E_mid = E[:mid_k, :mid_k].sum() - E_low # 8×8
119
+ E_total = E.sum()
120
+ E_high = max(E_total - (E_low + E_mid), 0.0)
121
+
122
+ denom = E_total + eps
123
+ r_low[by, bx] = E_low / denom
124
+ r_mid[by, bx] = E_mid / denom
125
+ r_high[by, bx] = E_high / denom
126
+ return r_low, r_mid, r_high
127
+
128
+ # Temporal map via FFT
129
+ def temporal_fft_map(gray_seq, *, block=16, hf_start_bin=2, eps=1e-6):
130
+ # each frame -> (K, gh, gw) block mean grid
131
+ grids = [block_mean(g.astype(np.float32), block=block) for g in gray_seq]
132
+ s = np.stack(grids, axis=0)
133
+
134
+ # rFFT along time
135
+ X = np.fft.rfft(s, axis=0) # (F, gh, gw)
136
+ E = (X.real * X.real + X.imag * X.imag).astype(np.float32) # power spectrum
137
+
138
+ # F = K // 2 + 1
139
+ dc = E[0]
140
+ if E.shape[0] <= 1:
141
+ z = np.zeros_like(dc, dtype=np.float32)
142
+ return z, z
143
+ P = E[1:] # drop DC -> (F-1, gh, gw)
144
+ non_dc = P.sum(axis=0)
145
+
146
+ # changes relative to DC
147
+ motion = non_dc / (dc + eps)
148
+
149
+ # flicker: change is in high temporal freqs
150
+ start = max(int(hf_start_bin) - 1, 0) # index in P
151
+ if start >= P.shape[0]:
152
+ flicker = np.zeros_like(non_dc, dtype=np.float32)
153
+ else:
154
+ hi = P[start:].sum(axis=0)
155
+ flicker = hi / (non_dc + eps)
156
+ return motion.astype(np.float32), flicker.astype(np.float32)
157
+
158
+ def fuse_temporal_maps(motion_grid, flicker_grid, *, beta=0.5):
159
+ m = norm01(motion_grid)
160
+ f = np.clip(flicker_grid, 0.0, 1.0)
161
+ # boosts where flicker is high
162
+ w = m * ((1.0 - beta) + beta * f)
163
+ return norm01(w)
164
+
165
+ # ----------------------------
166
+ # DCT -> two stream weights
167
+ # ----------------------------
168
+ def compute_twostream_dct(
169
+ gray_seq,
170
+ *,
171
+ block=16,
172
+ ):
173
+ K = len(gray_seq)
174
+ gray_anchor = gray_seq[0]
175
+ H, W = gray_anchor.shape
176
+
177
+ r_low_stack, r_mid_stack, r_high_stack = [], [], []
178
+ for g in gray_seq:
179
+ r_low, r_mid, r_high = dct_energy_ratios(g, block=block)
180
+ r_low_stack.append(r_low) # (gh, gw)
181
+ r_mid_stack.append(r_mid)
182
+ r_high_stack.append(r_high)
183
+ r_low_stack = np.stack(r_low_stack, axis=0) # (K, gh, gw)
184
+ r_mid_stack = np.stack(r_mid_stack, axis=0)
185
+ r_high_stack = np.stack(r_high_stack, axis=0)
186
+
187
+ # frequency band (anchor frame)
188
+ anchor_low_grid = r_low_stack[0] # (gh, gw)
189
+ anchor_mid_grid = r_mid_stack[0]
190
+ anchor_high_grid = r_high_stack[0]
191
+
192
+ # Ringing map (anchor)
193
+ edge_mask = edge_sobel(gray_anchor) # around edges
194
+ edge_frac = block_edge(edge_mask, block=block)
195
+ mh_band = r_mid_stack[0] + r_high_stack[0] # mid/high frequency energy
196
+ ring_score = np.maximum(mh_band, 0.0) # score = mid_high - 0 * r_low # alpha = 0
197
+ edge_min_frac = 0.05
198
+ ringing_grid = np.where(edge_frac >= edge_min_frac, edge_frac * ring_score, 0.0).astype(np.float32)
199
+ s = np.percentile(ringing_grid, 99) + 1e-6
200
+ ringing_grid01 = np.clip(ringing_grid / s, 0.0, 1.0)
201
+
202
+ # Blur map (anchor): like low-pass filtering
203
+ hf = 0.5 * r_mid_stack[0] + 1.0 * r_high_stack[0]
204
+ blur_raw = np.clip(1.0 - hf, 0.0, 1.0)
205
+ sobel_g = gradient_mag(gray_anchor).astype(np.float32)
206
+ sobel_g_grid = block_mean(sobel_g, block=block)
207
+ sobel_g_grid = norm01(sobel_g_grid) # soft structure weight
208
+ blur_grid = np.clip(blur_raw * sobel_g_grid, 0.0, 1.0)
209
+
210
+ # Blockiness map (anchor): boundary discontinuities
211
+ boundary_pix = blockiness_boundary_map(gray_anchor, block_size=block)
212
+ blockiness_grid = norm01(block_mean(boundary_pix, block=block))
213
+
214
+ # Temporal (window): FFT along time
215
+ if K >= 4:
216
+ motion_grid, flick_grid = temporal_fft_map(gray_seq, block=block, hf_start_bin=2)
217
+ temporal_grid = fuse_temporal_maps(motion_grid, flick_grid, beta=0.5)
218
+ elif K == 2:
219
+ E_stack = norm01(r_mid_stack + r_high_stack)
220
+ temporal_grid = norm01(np.abs(E_stack[1] - E_stack[0]))
221
+
222
+ # -------------Combine---------------
223
+ w_art = norm01(1.0 * ringing_grid01 + 1.0 * blur_grid + 1.0 * blockiness_grid + 1.0 * temporal_grid)
224
+ w_str = 1.0 - w_art
225
+
226
+ debug = {
227
+ # frequency band (anchor)
228
+ "dct_low_grid": anchor_low_grid,
229
+ "dct_mid_grid": anchor_mid_grid,
230
+ "dct_high_grid": anchor_high_grid,
231
+ # ringing (anchor)
232
+ "ringing_grid": ringing_grid01,
233
+ "edge_px": edge_mask,
234
+ # blur (anchor)
235
+ "blur_grid": blur_grid,
236
+ # blockiness (anchor)
237
+ "blockiness_grid": blockiness_grid,
238
+ # temporal (window)
239
+ "temporal_grid": temporal_grid,
240
+ }
241
+ return w_art, w_str, debug
242
+
243
+ # ----------------------------
244
+ # visualization panel
245
+ # ----------------------------
246
+ def save_panel(out_png, frame_rgb, w_art, w_str, debug):
247
+ fig = plt.figure(figsize=(16, 9), dpi=160)
248
+ def add(ax_i, title, img, cmap=None, vmin=0, vmax=1):
249
+ ax = fig.add_subplot(3, 4, ax_i)
250
+ ax.set_title(title)
251
+ if cmap is None:
252
+ ax.imshow(img)
253
+ else:
254
+ ax.imshow(img, cmap=cmap, vmin=vmin, vmax=vmax)
255
+ ax.axis("off")
256
+
257
+ add(1, "Frame_t (anchor)", frame_rgb)
258
+ add(2, "DCT LOW (grid)", norm01(debug["dct_low_grid"]), cmap="viridis")
259
+ add(3, "DCT MID (grid)", norm01(debug["dct_mid_grid"]), cmap="viridis")
260
+ add(4, "DCT HIGH (grid)", norm01(debug["dct_high_grid"]), cmap="viridis")
261
+ # ringing
262
+ add(5, "EDGE (mask)", debug["edge_px"], cmap="viridis")
263
+ add(6, "RINGING (mid/high, grid)", debug["ringing_grid"], cmap="viridis")
264
+ # blur
265
+ add(7, "BLUR (lowpass, grid)", debug["blur_grid"], cmap="viridis")
266
+ # blockiness
267
+ add(8, "BLOCKINESS (boundary, grid)", debug["blockiness_grid"], cmap="viridis")
268
+ # temporal
269
+ add(9, "TEMPORAL (grid)", debug["temporal_grid"], cmap="viridis")
270
+ # all map
271
+ add(10, "W_art", w_art, cmap="viridis")
272
+ add(11, "W_str", w_str, cmap="viridis")
273
+ os.makedirs(os.path.dirname(out_png), exist_ok=True)
274
+ fig.tight_layout()
275
+ fig.savefig(out_png)
276
+ plt.close(fig)
277
+
278
+
279
+ # ----------------------------
280
+ # Main
281
+ # ----------------------------
282
+ def main():
283
+ parser = argparse.ArgumentParser()
284
+
285
+ parser.add_argument("--video", default="/home/xinyi/Project/FD-VQA/test_videos/SDR_Animal_5ngj.mp4")
286
+ parser.add_argument("--out_dir", default="/home/xinyi/Project/FD-VQA/test_videos/freq_test_dct_only")
287
+ parser.add_argument("--size", type=int, default=224)
288
+
289
+ # fixed-T anchors over whole video (duration-uniform)
290
+ parser.add_argument("--num_anchors", type=int, default=16)
291
+ parser.add_argument("--win", type=int, default=6)
292
+ parser.add_argument("--win_step", type=int, default=1)
293
+ parser.add_argument("--block", type=int, default=16)
294
+
295
+ parser.add_argument("--no_panel", action="store_true")
296
+ args = parser.parse_args()
297
+ os.makedirs(args.out_dir, exist_ok=True)
298
+
299
+ vr = VideoReader(args.video, ctx=cpu(0))
300
+ total_frames = len(vr)
301
+ if total_frames <= 1:
302
+ raise RuntimeError(f"Video too short / frame count unavailable: total_frames={total_frames}")
303
+ print("total_frames:", total_frames)
304
+
305
+ size = args.size
306
+ win = args.win
307
+ win_step = args.win_step
308
+ num_anchors = args.num_anchors
309
+
310
+ anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
311
+ needed = collect_needed(anchor_idxs, total_frames, win, win_step)
312
+ print("anchor_idxs:", anchor_idxs)
313
+ cache = cache_needed_frames(vr, needed, size)
314
+ print("cached:", len(cache), "needed:", len(needed))
315
+
316
+ frame_all, w_art_all, w_str_all = [], [], []
317
+ anchors_kept = []
318
+ image_idx = 0
319
+
320
+ for anchor in tqdm(anchor_idxs, desc="Processing anchors (DCT)"):
321
+ out = read_window_from_cache(cache, anchor, total_frames, win, win_step)
322
+ if out is None:
323
+ continue
324
+ anchor_frame, gray_seq, idxs = out
325
+
326
+ w_art, w_str, dbg = compute_twostream_dct(
327
+ gray_seq,
328
+ block=args.block,
329
+ )
330
+
331
+ frame_all.append(anchor_frame)
332
+ w_art_all.append(w_art)
333
+ w_str_all.append(w_str)
334
+ anchors_kept.append(idxs)
335
+
336
+ image_idx += 1
337
+ if not args.no_panel:
338
+ save_panel(
339
+ os.path.join(args.out_dir, f"anchor_{anchor:03d}_{image_idx:02d}.png"),
340
+ anchor_frame,
341
+ w_art,
342
+ w_str,
343
+ dbg,
344
+ )
345
+
346
+ print(f"Done. Outputs saved to: {args.out_dir}")
347
+ print(anchors_kept)
348
+ print(f"total_frames={total_frames}, num_anchors_target={num_anchors}, anchors_produced={len(w_str_all)}")
349
+
350
+ if __name__ == "__main__":
351
+ main()
src/module/read_frame_decord.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
5
+ from decord import VideoReader, cpu
6
+
7
+ def sample_frames_uniform(total_frames, num_anchors, win=8, win_step=1):
8
+ if total_frames <= 0:
9
+ return [0] * num_anchors
10
+
11
+ min_win = (win - 1) * win_step + 1
12
+ if total_frames < num_anchors + min_win:
13
+ if total_frames >= num_anchors:
14
+ anchor_idxs = np.linspace(0, total_frames - 1, num_anchors).round().astype(int)
15
+ return anchor_idxs.tolist()
16
+ else:
17
+ anchor_idxs = list(range(total_frames))
18
+ anchor_idxs.extend([total_frames - 1] * (num_anchors - total_frames))
19
+ return anchor_idxs
20
+
21
+ max_start = total_frames - min_win
22
+ anchor_idxs = np.linspace(0, max_start, num_anchors).round().astype(int)
23
+ return anchor_idxs.tolist()
24
+
25
+ # window idx for anchor
26
+ def window_indices(anchor, total_frames, win, win_step):
27
+ last = total_frames - 1
28
+ idxs = [anchor + k * win_step for k in range(win)]
29
+ return [i if i < total_frames else last for i in idxs]
30
+
31
+ # collect idx for all windows
32
+ def collect_needed(anchor_idxs, total_frames, win, win_step):
33
+ needed = set()
34
+ for a in anchor_idxs:
35
+ for idx in window_indices(a, total_frames, win, win_step):
36
+ needed.add(idx)
37
+ return needed
38
+
39
+ # read frame for needed, read video in sequence
40
+ def cache_needed_frames(vr, needed, size):
41
+ if not needed:
42
+ return {}
43
+ max_idx = max(needed)
44
+ cache = {}
45
+
46
+ last_ok = -1
47
+ for i in range(max_idx + 1):
48
+ try:
49
+ # read next frame
50
+ frame_rgb = vr.next().asnumpy() # decord frame RGB (H,W,3)
51
+ except StopIteration:
52
+ # video ended early, like cap.read() failed
53
+ print(f"[DEBUG] cap.read() failed at i={i}, last_ok={last_ok}, max_idx={max_idx}")
54
+ print(f"[DEBUG] max_cached={max(cache.keys()) if cache else None}, last_ok={last_ok}")
55
+ break
56
+ except Exception as e:
57
+ # decode error
58
+ print(f"[DEBUG] cap.read() failed at i={i}, last_ok={last_ok}, max_idx={max_idx} | {e}")
59
+ print(f"[DEBUG] max_cached={max(cache.keys()) if cache else None}, last_ok={last_ok}")
60
+ break
61
+
62
+ last_ok = i
63
+ if i in needed:
64
+ # frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) # keep same as OpenCV: BGR
65
+ frame = cv2.resize(frame_rgb, (size, size), interpolation=cv2.INTER_AREA)
66
+ gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY).astype(np.float32)
67
+ cache[i] = (frame, gray)
68
+ return cache
69
+
70
+ # each anchor: read window from cache frames
71
+ def read_window_from_cache(cache, anchor, total_frames, win, win_step):
72
+ idxs = window_indices(anchor, total_frames, win, win_step)
73
+ # print(f"window idxs: {idxs}")
74
+ if not cache:
75
+ return None
76
+
77
+ # fallback = last frame
78
+ last_avail = max(cache.keys())
79
+ fallback = cache[last_avail]
80
+
81
+ anchor_frame = None
82
+ gray_seq = []
83
+ for j, idx in enumerate(idxs):
84
+ if idx in cache:
85
+ data = cache[idx]
86
+ fallback = data
87
+ else:
88
+ data = fallback
89
+
90
+ frame, gray = data
91
+ if j == 0: # window [0] is anchor
92
+ anchor_frame = frame
93
+ gray_seq.append(gray)
94
+
95
+ if len(gray_seq) == 1:
96
+ gray_seq.append(gray_seq[0])
97
+ idxs.append(idxs[0])
98
+ return anchor_frame, gray_seq, idxs
99
+
100
+
101
+ if __name__ == "__main__":
102
+ video_path = "/home/xinyi/Project/FD-VQA/test_videos/3369925072.mp4"
103
+ print("video:", video_path)
104
+
105
+ vr = VideoReader(video_path, ctx=cpu(0))
106
+ total_frames = len(vr)
107
+ if total_frames <= 1:
108
+ raise RuntimeError(f"Video too short / frame count unavailable: total_frames={total_frames}")
109
+ print("total_frames:", total_frames)
110
+
111
+ size = 224
112
+ win = 8
113
+ win_step = 1
114
+ num_anchors = 16
115
+
116
+ anchor_idxs = sample_frames_uniform(total_frames, num_anchors, win=win, win_step=win_step)
117
+ needed = collect_needed(anchor_idxs, total_frames, win, win_step)
118
+ print("anchor_idxs:", anchor_idxs)
119
+ print("needed:", needed)
120
+
121
+ cache = cache_needed_frames(vr, needed, size)
122
+ print("cached:", len(cache), "needed:", len(needed))
123
+
124
+ frames_list, grays_list, processed_list = [], [], []
125
+ for a in anchor_idxs:
126
+ out = read_window_from_cache(cache, a, total_frames, win, win_step)
127
+ if out is None:
128
+ continue
129
+ anchor_frame, gray_seq, idxs = out
130
+ frames_list.append(anchor_frame)
131
+ grays_list.append(gray_seq)
132
+ processed_list.append(idxs)
133
+
134
+
135
+ print("frames:", len(frames_list))
136
+ print("one frame shape:", frames_list[0].shape)
137
+ print("each window length:", len(grays_list[0]) if grays_list else 0)
138
+ print("window idxs:", processed_list if processed_list else None)
src/train.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
3
+ import argparse
4
+ import csv
5
+ from pathlib import Path
6
+ import numpy as np
7
+ from scipy.stats import pearsonr as _pearsonr_scipy, spearmanr as _spearmanr_scipy
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import cv2
11
+ from decord import VideoReader, cpu
12
+ from torch.amp import GradScaler, autocast
13
+ from tqdm import tqdm
14
+
15
+ from module.compute_weight_map import process_video, compute_weight_map
16
+ from model.qd_model import QD_MODEL
17
+
18
+
19
+ # ----------------------------
20
+ # Data utils
21
+ # ----------------------------
22
+ def read_vid_mos_csv(csv_path):
23
+ rows = []
24
+ with open(csv_path, "r", encoding="utf-8") as f:
25
+ reader = csv.DictReader(f)
26
+ if not reader.fieldnames:
27
+ raise RuntimeError("CSV has no header")
28
+ for r in reader:
29
+ vid = str(r["vid"]).strip()
30
+ mos = float(r["mos"])
31
+ rows.append((vid, mos))
32
+ return rows
33
+
34
+ def split_rows(rows, seed=42, test_ratio=0.2, val_ratio=0.1):
35
+ rng = np.random.default_rng(int(seed))
36
+ idx = np.arange(len(rows))
37
+ rng.shuffle(idx)
38
+
39
+ n = len(rows)
40
+ n_test = int(round(n * test_ratio))
41
+ n_train_all = n - n_test # train+val
42
+ n_val = int(round(n_train_all * val_ratio)) # val from train_all
43
+
44
+ val = [rows[i] for i in idx[:n_val]]
45
+ train = [rows[i] for i in idx[n_val:n_train_all]]
46
+ test = [rows[i] for i in idx[n_train_all:]]
47
+ return train, val, test
48
+
49
+ def split_train_val(rows, seed=42, val_ratio=0.1):
50
+ rng = np.random.default_rng(int(seed))
51
+ idx = np.arange(len(rows))
52
+ rng.shuffle(idx)
53
+
54
+ n = len(rows)
55
+ n_val = int(round(n * val_ratio))
56
+ val = [rows[i] for i in idx[:n_val]]
57
+ train = [rows[i] for i in idx[n_val:]]
58
+ return train, val
59
+
60
+ def pearsonr(x, y, eps=1e-12):
61
+ # PLCC (SciPy): returns a torch scalar tensor so call-site ".item()" still works
62
+ if hasattr(x, "detach"):
63
+ x = x.detach().cpu().numpy()
64
+ if hasattr(y, "detach"):
65
+ y = y.detach().cpu().numpy()
66
+ x = np.asarray(x).reshape(-1)
67
+ y = np.asarray(y).reshape(-1)
68
+
69
+ # avoid NaN when constant / too short
70
+ if x.size < 2 or np.std(x) < eps or np.std(y) < eps:
71
+ return torch.tensor(0.0)
72
+ r, _p = _pearsonr_scipy(x, y)
73
+ if np.isnan(r):
74
+ r = 0.0
75
+ return torch.tensor(float(r))
76
+
77
+ def spearmanr(x, y, eps=1e-12):
78
+ # SRCC (SciPy): handles ties correctly; returns torch scalar tensor
79
+ if hasattr(x, "detach"):
80
+ x = x.detach().cpu().numpy()
81
+ if hasattr(y, "detach"):
82
+ y = y.detach().cpu().numpy()
83
+
84
+ x = np.asarray(x).reshape(-1)
85
+ y = np.asarray(y).reshape(-1)
86
+ if x.size < 2 or np.std(x) < eps or np.std(y) < eps:
87
+ return torch.tensor(0.0)
88
+
89
+ r, _p = _spearmanr_scipy(x, y)
90
+ if np.isnan(r):
91
+ r = 0.0
92
+ return torch.tensor(float(r))
93
+
94
+ # ----------------------------
95
+ # Train utils
96
+ # ----------------------------
97
+ def build_scheduler(optim, args):
98
+ warm = int(args.warmup_epochs)
99
+ total = int(args.epochs)
100
+ warm = max(0, min(warm, total - 1))
101
+
102
+ # warmup
103
+ warmup = torch.optim.lr_scheduler.LinearLR(
104
+ optim,
105
+ start_factor=0.1,
106
+ total_iters=warm if warm > 0 else 1,
107
+ )
108
+ # cosine warmup
109
+ cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
110
+ optim,
111
+ T_max=(total - warm) if (total - warm) > 0 else 1,
112
+ eta_min=float(args.min_lr),
113
+ )
114
+
115
+ if warm > 0:
116
+ return torch.optim.lr_scheduler.SequentialLR(
117
+ optim,
118
+ schedulers=[warmup, cosine],
119
+ milestones=[warm],
120
+ )
121
+ return cosine
122
+
123
+ def com_loss(y_pred, y_true, reg_w=0.6, rank_w=1.0, huber_beta=1.0, margin=0.0):
124
+ # 1) Huber / SmoothL1
125
+ if huber_beta is None:
126
+ reg_loss = F.l1_loss(y_pred, y_true, reduction="mean")
127
+ else:
128
+ reg_loss = F.smooth_l1_loss(y_pred, y_true, beta=float(huber_beta), reduction="mean")
129
+ reg_loss = reg_loss * float(reg_w)
130
+
131
+ # 2) pairwise hinge rank
132
+ B = y_true.shape[0]
133
+ if B < 2 or float(rank_w) == 0.0:
134
+ rank_loss = y_pred.new_tensor(0.0)
135
+ return reg_loss + rank_loss, reg_loss, rank_loss
136
+ pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0) # [B,B]
137
+ true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0) # [B,B]
138
+ s = torch.sign(true_diff) # -1,0,+1
139
+
140
+ # y_true = pair, ignore
141
+ mask = (s != 0).float()
142
+ # hinge: max(0, margin - s*(pred_i - pred_j))
143
+ rank_mat = F.relu(float(margin) - s * pred_diff) * mask
144
+ denom = mask.sum().clamp_min(1.0)
145
+ rank_loss = rank_mat.sum() / denom
146
+ rank_loss = rank_loss * float(rank_w)
147
+
148
+ total = reg_loss + rank_loss
149
+ return total, reg_loss, rank_loss
150
+
151
+ # ----------------------------
152
+ # Dataset
153
+ # ----------------------------
154
+ class VQADataset(torch.utils.data.Dataset):
155
+ """
156
+ Returns per item:
157
+ rgb: [3, T, H, W] float in [0,1] (RGB)
158
+ w_art: [1, T, H, W] float in [0,1]
159
+ w_str: [1, T, H, W] float in [0,1]
160
+ y: scalar float (MOS, optional normalized)
161
+ vid: str
162
+ """
163
+ def __init__(self, rows, db_path, clip_len, size, win, win_step, mos_mean=None, mos_std=None):
164
+ self.rows = rows
165
+ self.db_path = str(db_path)
166
+ self.clip_len = int(clip_len)
167
+ self.size = int(size)
168
+ self.win = int(win)
169
+ self.win_step = int(win_step)
170
+ self.mos_mean = mos_mean
171
+ self.mos_std = mos_std
172
+
173
+ def __len__(self):
174
+ return len(self.rows)
175
+
176
+ def __getitem__(self, idx):
177
+ vid, mos = self.rows[int(idx)]
178
+ num_anchors = self.clip_len
179
+ size = self.size
180
+ win = self.win
181
+ win_step = self.win_step
182
+
183
+ # get video path
184
+ base_path = Path(self.db_path) / vid
185
+ video_path = None
186
+ for ext in ("mp4", "avi", "mkv"):
187
+ p = Path(str(base_path) + f".{ext}")
188
+ if p.exists():
189
+ video_path = str(p)
190
+ break
191
+ if video_path is None:
192
+ raise FileNotFoundError(f"Cannot find {vid} video")
193
+
194
+ try:
195
+ # read video
196
+ vr = VideoReader(video_path, ctx=cpu(0))
197
+ frame_all, w_art_all, w_str_all, anchors_kept = process_video(
198
+ vr,
199
+ size=size, num_anchors=num_anchors, win=win, win_step=win_step,
200
+ )
201
+ frames_np, w_art_np, w_str_np = compute_weight_map(frame_all, w_art_all, w_str_all)
202
+ # print("frames_np:", frames_np.shape, frames_np.dtype)
203
+ # print("w_art_np:", w_art_np.shape, w_art_np.dtype)
204
+ # print("w_str_np:", w_str_np.shape, w_str_np.dtype)
205
+ # print("anchors_kept:", len(anchors_kept), "example:", anchors_kept[0])
206
+ except Exception as e:
207
+ print("\n[DATA ERROR]")
208
+ print("idx:", idx)
209
+ print("vid:", vid)
210
+ raise
211
+ finally:
212
+ # release decord video reader
213
+ try:
214
+ if vr is not None:
215
+ del vr
216
+ except Exception:
217
+ pass
218
+
219
+ # fixed length sampling
220
+ T = self.clip_len
221
+
222
+ # frames_sel = [cv2.cvtColor(frames_np[i], cv2.COLOR_BGR2RGB) for i in range(T)] # frame BGR -> RGB: [T,H,W,3] -> [3,T,H,W]
223
+ frames_sel = [frames_np[i] for i in range(T)] # RGB frames_np
224
+ rgb = torch.from_numpy(np.stack(frames_sel, axis=0)).float()
225
+ rgb = rgb.permute(3, 0, 1, 2).contiguous() / 255.0
226
+
227
+ # W_art / W_str: [T,H,W] -> [1,T,H,W]
228
+ w_art = torch.from_numpy(np.stack([w_art_np[i] for i in range(T)], axis=0).astype(np.float32)).unsqueeze(0).float()
229
+ w_str = torch.from_numpy(np.stack([w_str_np[i] for i in range(T)], axis=0).astype(np.float32)).unsqueeze(0).float()
230
+
231
+ # MOS
232
+ y = float(mos)
233
+ if self.mos_mean is not None and self.mos_std is not None:
234
+ y = (y - self.mos_mean) / (self.mos_std + 1e-8)
235
+ y = torch.tensor(y).float()
236
+ return rgb, w_art, w_str, y, str(vid)
237
+
238
+
239
+ # ----------------------------
240
+ # Train
241
+ # ----------------------------
242
+ @torch.no_grad()
243
+ def _gather_cat(xs):
244
+ if not xs:
245
+ return torch.empty(0)
246
+ return torch.cat(xs, dim=0)
247
+
248
+ def run_epoch(model, loader, device, *, optim=None, amp=True, mos_mean=None, mos_std=None, desc="", show_pbar=True, log_interval=10):
249
+ is_train = optim is not None
250
+ model.train(is_train)
251
+
252
+ scaler = getattr(run_epoch, "_scaler", None)
253
+ if scaler is None:
254
+ device_type = "cuda" if str(device).startswith("cuda") else "cpu"
255
+ run_epoch._scaler = GradScaler(device_type, enabled=(amp and device_type == "cuda"))
256
+ scaler = run_epoch._scaler
257
+
258
+ losses = []
259
+ y_all = []
260
+ yhat_all = []
261
+
262
+ # ---- tqdm progress bar ----
263
+ it = loader
264
+ if show_pbar:
265
+ it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True)
266
+
267
+ for step, (rgb, w_art, w_str, y, vid) in enumerate(it, start=1):
268
+ rgb = rgb.to(device, non_blocking=True) # [B,3,T,H,W]
269
+ w_art = w_art.to(device, non_blocking=True) # [B,1,T,H,W]
270
+ w_str = w_str.to(device, non_blocking=True) # [B,1,T,H,W]
271
+ y = y.to(device, non_blocking=True).float() # [B]
272
+
273
+ if is_train:
274
+ optim.zero_grad(set_to_none=True)
275
+
276
+ device_type = "cuda" if str(device).startswith("cuda") else "cpu"
277
+ if is_train:
278
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
279
+ yhat, _aux = model(rgb, w_art, w_str) # yhat: [B]
280
+ loss, loss_reg, loss_rank = com_loss(yhat, y)
281
+
282
+ scaler.scale(loss).backward()
283
+ scaler.step(optim)
284
+ scaler.update()
285
+ else:
286
+ with torch.inference_mode():
287
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
288
+ yhat, _aux = model(rgb, w_art, w_str)
289
+ loss, loss_reg, loss_rank = com_loss(yhat, y)
290
+
291
+ loss_cpu = loss.detach().float().cpu()
292
+ losses.append(loss_cpu)
293
+ y_all.append(y.detach().float().cpu())
294
+ yhat_all.append(yhat.detach().float().cpu())
295
+
296
+ # ---- update bar every log_interval steps ----
297
+ if show_pbar and (step % int(log_interval) == 0 or step == len(loader)):
298
+ avg_loss_so_far = torch.stack(losses).mean().item()
299
+ lrs = None
300
+ if is_train and hasattr(optim, "param_groups") and optim.param_groups:
301
+ lrs = [pg.get("lr", None) for pg in optim.param_groups]
302
+
303
+ postfix = {"loss": f"{avg_loss_so_far:.4f}"}
304
+ if lrs is not None:
305
+ postfix["lrs"] = ",".join([f"{x:.2e}" for x in lrs if x is not None])
306
+ it.set_postfix(postfix)
307
+
308
+ y_all = _gather_cat(y_all)
309
+ yhat_all = _gather_cat(yhat_all)
310
+
311
+ # ---- 反标准化:在 MOS 原尺度上算相关系数 ----
312
+ if mos_mean is not None and mos_std is not None:
313
+ y_all = y_all * mos_std + mos_mean
314
+ yhat_all = yhat_all * mos_std + mos_mean
315
+
316
+ plcc = pearsonr(y_all, yhat_all).item() if y_all.numel() > 1 else 0.0
317
+ srcc = spearmanr(y_all, yhat_all).item() if y_all.numel() > 1 else 0.0
318
+ rmse = torch.sqrt(torch.mean((yhat_all - y_all) ** 2)).item() if y_all.numel() > 0 else 0.0
319
+
320
+ avg_loss = torch.stack(losses).mean().item() if losses else 0.0
321
+ return avg_loss, plcc, srcc, rmse
322
+
323
+
324
+ def main():
325
+ ap = argparse.ArgumentParser()
326
+ # ----- data -----
327
+ ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/LSVQ_TRAIN_metadata.csv")
328
+ ap.add_argument("--db_path", default="/media/xinyi/server/LSVQ/")
329
+ ap.add_argument("--split_seed", type=int, default=42)
330
+ ap.add_argument("--test_ratio", type=float, default=0.2)
331
+ ap.add_argument("--val_ratio", type=float, default=0.1) # train 80%
332
+ # ----- video processing -----
333
+ ap.add_argument("--clip_len", type=int, default=16)
334
+ ap.add_argument("--resize", type=int, default=224)
335
+ ap.add_argument("--win", type=int, default=6)
336
+ ap.add_argument("--win_step", type=int, default=1)
337
+ # ----- runtime -----
338
+ ap.add_argument("--batch_size", type=int, default=8)
339
+ ap.add_argument("--num_workers", type=int, default=2)
340
+ ap.add_argument("--device", type=str, default="cuda")
341
+ ap.add_argument("--no_amp", action="store_true")
342
+ # ----- hyperparams -----
343
+ ap.add_argument("--epochs", type=int, default=35)
344
+ ap.add_argument("--warmup_epochs", type=int, default=3)
345
+ ap.add_argument("--lr", type=float, default=1e-5)
346
+ ap.add_argument("--min_lr", type=float, default=1e-6)
347
+ ap.add_argument("--finetune_lr", type=float, default=5e-5)
348
+ ap.add_argument("--weight_decay", type=float, default=1e-2)
349
+ ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
350
+ ap.add_argument("--finetune_last_stage", action="store_true")
351
+ ap.add_argument("--patience", type=int, default=6)
352
+ # ----- save -----
353
+ ap.add_argument("--save_dir", type=str, default="checkpoints")
354
+ ap.add_argument("--save_name", type=str, default="qd_model.pt")
355
+
356
+ args = ap.parse_args()
357
+ torch.manual_seed(args.split_seed)
358
+ device = torch.device(args.device)
359
+ amp = not bool(args.no_amp)
360
+
361
+ # ----------------------------
362
+ # Load rows and split
363
+ # ----------------------------
364
+ csv_path = Path(args.csv_path)
365
+ if csv_path.name == "LSVQ_TRAIN_metadata.csv":
366
+ # LSVQ official split
367
+ test_csv = csv_path.parent / "LSVQ_TEST_metadata.csv"
368
+ if not test_csv.exists():
369
+ raise FileNotFoundError(f"Cannot find LSVQ test csv: {test_csv}")
370
+ train_all = read_vid_mos_csv(str(csv_path))
371
+ test_rows = read_vid_mos_csv(str(test_csv))
372
+ train_rows, val_rows = split_train_val(
373
+ train_all,
374
+ seed=args.split_seed,
375
+ val_ratio=args.val_ratio,
376
+ )
377
+ print(f"[LSVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
378
+ elif csv_path.name == "KVQ_TRAIN_metadata.csv":
379
+ # KVQ challenge split
380
+ val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
381
+ test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
382
+ train_rows = read_vid_mos_csv(str(csv_path))
383
+ val_rows = read_vid_mos_csv(str(val_csv))
384
+ test_rows = read_vid_mos_csv(str(test_csv))
385
+ print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
386
+ else:
387
+ # default split for other datasets
388
+ rows = read_vid_mos_csv(str(csv_path))
389
+ train_rows, val_rows, test_rows = split_rows(
390
+ rows,
391
+ seed=args.split_seed,
392
+ test_ratio=args.test_ratio,
393
+ val_ratio=args.val_ratio,
394
+ )
395
+ # print("sizes:", len(rows), len(train_rows), len(val_rows), len(test_rows))
396
+ # print("train first 3:", train_rows[:3])
397
+ # print("val first 3:", val_rows[:3])
398
+ # print("test first 3:", test_rows[:3])
399
+
400
+ # MOS normalization stats from train split
401
+ mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
402
+ mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
403
+ mos_std = float(mos_train.std()) if len(mos_train) else 1.0
404
+ if mos_std <= 1e-8:
405
+ mos_std = 1.0
406
+
407
+ # ----------------------------
408
+ # DB and datasets
409
+ # ----------------------------
410
+ ds_train = VQADataset(
411
+ train_rows, args.db_path,
412
+ clip_len=args.clip_len,
413
+ size=args.resize,
414
+ win=args.win,
415
+ win_step=args.win_step,
416
+ mos_mean=mos_mean,
417
+ mos_std=mos_std,
418
+ )
419
+ ds_val = VQADataset(
420
+ val_rows, args.db_path,
421
+ clip_len=args.clip_len,
422
+ size=args.resize,
423
+ win=args.win,
424
+ win_step=args.win_step,
425
+ mos_mean=mos_mean,
426
+ mos_std=mos_std,
427
+ )
428
+ ds_test = VQADataset(
429
+ test_rows, args.db_path,
430
+ clip_len=args.clip_len,
431
+ size=args.resize,
432
+ win=args.win,
433
+ win_step=args.win_step,
434
+ mos_mean=mos_mean,
435
+ mos_std=mos_std,
436
+ )
437
+
438
+ pin = str(device).startswith("cuda")
439
+
440
+ loader_train = torch.utils.data.DataLoader(
441
+ ds_train,
442
+ batch_size=args.batch_size,
443
+ shuffle=True,
444
+ num_workers=args.num_workers,
445
+ pin_memory=pin,
446
+ # persistent_workers=(args.num_workers > 0),
447
+ prefetch_factor=4 if args.num_workers > 0 else None,
448
+ drop_last=False,
449
+ )
450
+ loader_val = torch.utils.data.DataLoader(
451
+ ds_val,
452
+ batch_size=args.batch_size,
453
+ shuffle=False,
454
+ num_workers=args.num_workers,
455
+ pin_memory=pin,
456
+ drop_last=False,
457
+ )
458
+ loader_test = torch.utils.data.DataLoader(
459
+ ds_test,
460
+ batch_size=args.batch_size,
461
+ shuffle=False,
462
+ num_workers=args.num_workers,
463
+ pin_memory=pin,
464
+ drop_last=False,
465
+ )
466
+
467
+ # ----------------------------
468
+ # Model
469
+ # ----------------------------
470
+ model = QD_MODEL(
471
+ clip_model="openai/clip-vit-base-patch16",
472
+ ).to(device)
473
+
474
+ # Stage A: freeze CLIP
475
+ model.freeze_clip_all()
476
+
477
+ clip_params_all = []
478
+ other_params_all = []
479
+ for name, p in model.named_parameters():
480
+ if name.startswith("encoder."):
481
+ clip_params_all.append(p)
482
+ else:
483
+ other_params_all.append(p)
484
+
485
+ param_groups = []
486
+ if other_params_all:
487
+ param_groups.append({"params": other_params_all, "lr": float(args.lr)})
488
+ if clip_params_all:
489
+ param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})
490
+
491
+ optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
492
+ scheduler = build_scheduler(optim, args)
493
+ # ----------------------------
494
+ # Train loop
495
+ # ----------------------------
496
+ save_dir = Path(args.save_dir)
497
+ save_dir.mkdir(parents=True, exist_ok=True)
498
+ save_path = save_dir / args.save_name
499
+ best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
500
+ best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")
501
+
502
+ best_val_srcc = -1e18
503
+ bad_epochs = 0
504
+ did_unfreeze = False
505
+
506
+ for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Epochs", dynamic_ncols=True):
507
+
508
+ # Stage B: optional CLIP finetune after warmup
509
+ if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
510
+ model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
511
+ did_unfreeze = True
512
+
513
+ # 不重建 optim, 保留 Adam 状态,只更新 lr
514
+ if hasattr(optim, "param_groups") and len(optim.param_groups) >= 2:
515
+ optim.param_groups[0]["lr"] = float(args.lr)
516
+ optim.param_groups[1]["lr"] = float(args.finetune_lr)
517
+
518
+ print(
519
+ f"[Stage B] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks | "
520
+ f"lr={float(args.lr)} finetune_lr={float(args.finetune_lr)}"
521
+ )
522
+
523
+ tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
524
+ model, loader_train, device,
525
+ optim=optim,
526
+ amp=amp,
527
+ mos_mean=mos_mean,
528
+ mos_std=mos_std,
529
+ desc=f"Train e{epoch:03d}",
530
+ show_pbar=True,
531
+ log_interval=10,
532
+ )
533
+ va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
534
+ model, loader_val, device,
535
+ optim=None,
536
+ amp=amp,
537
+ mos_mean=mos_mean,
538
+ mos_std=mos_std,
539
+ desc=f"Val e{epoch:03d}",
540
+ show_pbar=True,
541
+ log_interval=10,
542
+ )
543
+ scheduler.step()
544
+ print(
545
+ f"epoch {epoch:03d} | "
546
+ f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
547
+ f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
548
+ )
549
+
550
+ # Save "last" checkpoint every epoch
551
+ ckpt = {
552
+ "epoch": epoch,
553
+ "model": model.state_dict(),
554
+ "optim": optim.state_dict(),
555
+ "mos_mean": mos_mean,
556
+ "mos_std": mos_std,
557
+ "args": vars(args),
558
+ "best_val_srcc": best_val_srcc,
559
+ }
560
+ torch.save(ckpt, str(save_path))
561
+
562
+ # Save best by val SRCC (higher is better)
563
+ if va_srcc > best_val_srcc:
564
+ best_val_srcc = va_srcc
565
+ bad_epochs = 0
566
+ ckpt["best_val_srcc"] = best_val_srcc
567
+ torch.save(ckpt, str(best_path))
568
+ torch.save(model.state_dict(), str(best_weights_path))
569
+ print(
570
+ f" [best] val_srcc={best_val_srcc:.4f} (val_rmse={va_rmse:.4f}) -> saved "
571
+ f"{best_path} and {best_weights_path}"
572
+ )
573
+ else:
574
+ bad_epochs += 1
575
+ if bad_epochs >= int(args.patience):
576
+ print(
577
+ f"[EarlyStop] val_srcc did not improve for {bad_epochs} epochs. "
578
+ f"Stop at epoch {epoch}."
579
+ )
580
+ break
581
+ # ----------------------------
582
+ # Test (load best)
583
+ # ----------------------------
584
+ if best_weights_path.exists():
585
+ sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
586
+ model.load_state_dict(sd, strict=True)
587
+ print(f"Loaded best weights: {best_weights_path}")
588
+ elif best_path.exists():
589
+ best = torch.load(str(best_path), map_location=device)
590
+ model.load_state_dict(best["model"], strict=True)
591
+ print(f"Loaded best checkpoint: {best_path} (val_srcc={best.get('best_val_srcc', None)})")
592
+
593
+ te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
594
+ model, loader_test, device,
595
+ optim=None,
596
+ amp=amp,
597
+ mos_mean=mos_mean,
598
+ mos_std=mos_std,
599
+ )
600
+ print(f"TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
601
+
602
+
603
+ if __name__ == "__main__":
604
+ main()
src/transfer.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+
6
+ from train import (
7
+ read_vid_mos_csv,
8
+ split_rows,
9
+ VQADataset,
10
+ run_epoch,
11
+ build_scheduler,
12
+ )
13
+ from model.qd_model import QD_MODEL
14
+
15
+
16
+ def load_pretrained_weights(model, pretrained_path, device):
17
+ p = Path(pretrained_path)
18
+ obj = torch.load(str(p), map_location=device, weights_only=True)
19
+ # *.pt: dict checkpoint
20
+ if isinstance(obj, dict) and "model" in obj:
21
+ model.load_state_dict(obj["model"], strict=True)
22
+ return obj
23
+ # best_weights.pt: state_dict
24
+ else:
25
+ model.load_state_dict(obj, strict=True)
26
+ return None
27
+
28
+ def make_loaders(rows_train, rows_val, rows_test, args, mos_mean, mos_std, device):
29
+ ds_train = VQADataset(
30
+ rows_train, args.db_path,
31
+ clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
32
+ mos_mean=mos_mean, mos_std=mos_std,
33
+ )
34
+ ds_val = VQADataset(
35
+ rows_val, args.db_path,
36
+ clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
37
+ mos_mean=mos_mean, mos_std=mos_std,
38
+ )
39
+ ds_test = VQADataset(
40
+ rows_test, args.db_path,
41
+ clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
42
+ mos_mean=mos_mean, mos_std=mos_std,
43
+ )
44
+
45
+ pin = str(device).startswith("cuda")
46
+ loader_train = torch.utils.data.DataLoader(
47
+ ds_train, batch_size=args.batch_size, shuffle=True,
48
+ num_workers=args.num_workers, pin_memory=pin,
49
+ persistent_workers=(args.num_workers > 0),
50
+ prefetch_factor=4 if args.num_workers > 0 else None,
51
+ drop_last=False,
52
+ )
53
+ loader_val = torch.utils.data.DataLoader(
54
+ ds_val, batch_size=args.batch_size, shuffle=False,
55
+ num_workers=args.num_workers, pin_memory=pin, drop_last=False,
56
+ )
57
+ loader_test = torch.utils.data.DataLoader(
58
+ ds_test, batch_size=args.batch_size, shuffle=False,
59
+ num_workers=args.num_workers, pin_memory=pin, drop_last=False,
60
+ )
61
+ return loader_train, loader_val, loader_test
62
+
63
+
64
+ def main():
65
+ ap = argparse.ArgumentParser()
66
+ # ----- mode -----
67
+ ap.add_argument("--mode", choices=["finetune", "test_only"], required=True)
68
+ ap.add_argument("--pretrained", default="/home/xinyi/Project/FD-VQA/src/checkpoints/kvq/qd_model.best.pt", help="pretrain model path")
69
+ # ----- data -----
70
+ ap.add_argument("--csv_path", default="/home/xinyi/Project/FD-VQA/metadata/SHORTS-SDR-DATASET_metadata.csv")
71
+ ap.add_argument("--db_path", default="/media/xinyi/server/video_dataset/shorts-hdr-dataset/sdr/")
72
+ ap.add_argument("--split_seed", type=int, default=0)
73
+ ap.add_argument("--test_ratio", type=float, default=0.2)
74
+ ap.add_argument("--val_ratio", type=float, default=0.1) # train 80%
75
+ # ----- video processing -----
76
+ ap.add_argument("--clip_len", type=int, default=16)
77
+ ap.add_argument("--resize", type=int, default=224)
78
+ ap.add_argument("--win", type=int, default=6)
79
+ ap.add_argument("--win_step", type=int, default=1)
80
+ # ----- runtime -----
81
+ ap.add_argument("--batch_size", type=int, default=8)
82
+ ap.add_argument("--num_workers", type=int, default=4)
83
+ ap.add_argument("--device", type=str, default="cuda")
84
+ ap.add_argument("--no_amp", action="store_true")
85
+ # ----- finetune hyperparams -----
86
+ ap.add_argument("--epochs", type=int, default=10)
87
+ ap.add_argument("--warmup_epochs", type=int, default=1)
88
+ ap.add_argument("--lr", type=float, default=1e-5)
89
+ ap.add_argument("--min_lr", type=float, default=1e-6)
90
+ ap.add_argument("--finetune_lr", type=float, default=5e-5)
91
+ ap.add_argument("--weight_decay", type=float, default=1e-2)
92
+ ap.add_argument("--clip_unfreeze_blocks", type=int, default=4)
93
+ ap.add_argument("--finetune_last_stage", action="store_true")
94
+ ap.add_argument("--patience", type=int, default=6)
95
+ # ----- save -----
96
+ ap.add_argument("--save_dir", type=str, default="checkpoints_transfer")
97
+ ap.add_argument("--save_name", type=str, default="transfer.pt")
98
+
99
+ # ----- MOS normalization control -----
100
+ # finetune: use target train mean/std
101
+ ap.add_argument("--test_only_norm", choices=["none", "use_source_ckpt"], default="use_source_ckpt")
102
+
103
+ args = ap.parse_args()
104
+ torch.manual_seed(args.split_seed)
105
+ device = torch.device(args.device)
106
+ amp = not bool(args.no_amp)
107
+
108
+ # ----------------------------
109
+ # Load target rows and split
110
+ # ----------------------------
111
+ rows = read_vid_mos_csv(args.csv_path)
112
+ if args.mode == "finetune":
113
+ csv_path = Path(args.csv_path)
114
+ if csv_path.name == "KVQ_TRAIN_metadata.csv":
115
+ # KVQ challenge split
116
+ val_csv = csv_path.parent / "KVQ_VAL_metadata.csv"
117
+ test_csv = csv_path.parent / "KVQ_TEST_metadata.csv"
118
+ train_rows = read_vid_mos_csv(str(csv_path))
119
+ val_rows = read_vid_mos_csv(str(val_csv))
120
+ test_rows = read_vid_mos_csv(str(test_csv))
121
+ print(f"[KVQ split] train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
122
+ else:
123
+ train_rows, val_rows, test_rows = split_rows(
124
+ rows, seed=args.split_seed, test_ratio=args.test_ratio, val_ratio=args.val_ratio
125
+ )
126
+ mos_train = np.array([mos for _vid, mos in train_rows], dtype=np.float32)
127
+ mos_mean = float(mos_train.mean()) if len(mos_train) else 0.0
128
+ mos_std = float(mos_train.std()) if len(mos_train) else 1.0
129
+ if mos_std <= 1e-8:
130
+ mos_std = 1.0
131
+ else:
132
+ # test_only
133
+ train_rows, val_rows, test_rows = [], [], rows
134
+ mos_mean, mos_std = None, None
135
+
136
+ # ----------------------------
137
+ # Build model, load pretrained
138
+ # ----------------------------
139
+ model = QD_MODEL(
140
+ clip_model="openai/clip-vit-base-patch16",
141
+ ).to(device)
142
+
143
+ ckpt = load_pretrained_weights(model, args.pretrained, device)
144
+ print(f"Loaded pretrained: {args.pretrained}")
145
+
146
+ # ----------------------------
147
+ # Decide normalization for test_only
148
+ # ----------------------------
149
+ if args.mode == "test_only":
150
+ if args.test_only_norm == "use_source_ckpt" and isinstance(ckpt, dict):
151
+ # 注:这样对 SRCC/PLCC 没影响
152
+ mos_mean = ckpt.get("mos_mean", None)
153
+ mos_std = ckpt.get("mos_std", None)
154
+ if mos_mean is None or mos_std is None:
155
+ mos_mean, mos_std = None, None
156
+ print("[warn] pretrained ckpt has no mos_mean/std, fallback to no normalization.")
157
+ else:
158
+ print(f"test_only uses source mos_mean/std from ckpt: mean={mos_mean:.4f}, std={mos_std:.4f}")
159
+ else:
160
+ print("test_only uses no MOS normalization.")
161
+
162
+ # ----------------------------
163
+ # Mode: test_only
164
+ # ----------------------------
165
+ if args.mode == "test_only":
166
+ ds_test = VQADataset(
167
+ test_rows, args.db_path,
168
+ clip_len=args.clip_len, size=args.resize, win=args.win, win_step=args.win_step,
169
+ mos_mean=mos_mean, mos_std=mos_std,
170
+ )
171
+ pin = str(device).startswith("cuda")
172
+ loader_test = torch.utils.data.DataLoader(
173
+ ds_test, batch_size=args.batch_size, shuffle=False,
174
+ num_workers=args.num_workers, pin_memory=pin, drop_last=False,
175
+ )
176
+ print("num_test_rows =", len(test_rows))
177
+ print("len(ds_test) =", len(ds_test))
178
+ print("len(loader_test) =", len(loader_test), "batch_size =", args.batch_size)
179
+
180
+ te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
181
+ model, loader_test, device,
182
+ optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
183
+ desc="TestOnly", show_pbar=True
184
+ )
185
+ print(f"TEST_ONLY | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
186
+ return
187
+
188
+ # ----------------------------
189
+ # DataLoaders
190
+ # ----------------------------
191
+ loader_train, loader_val, loader_test = make_loaders(
192
+ train_rows, val_rows, test_rows, args,
193
+ mos_mean=mos_mean, mos_std=mos_std, device=device
194
+ )
195
+
196
+ # ----------------------------
197
+ # Mode: finetune
198
+ # ----------------------------
199
+ model.freeze_clip_all()
200
+ did_unfreeze = False
201
+
202
+ clip_params_all = []
203
+ other_params_all = []
204
+ for name, p in model.named_parameters():
205
+ if name.startswith("encoder."):
206
+ clip_params_all.append(p)
207
+ else:
208
+ other_params_all.append(p)
209
+
210
+ param_groups = []
211
+ if other_params_all:
212
+ param_groups.append({"params": other_params_all, "lr": float(args.lr)})
213
+ if clip_params_all:
214
+ param_groups.append({"params": clip_params_all, "lr": float(args.finetune_lr)})
215
+
216
+ optim = torch.optim.AdamW(param_groups, weight_decay=float(args.weight_decay))
217
+ scheduler = build_scheduler(optim, args)
218
+
219
+ save_dir = Path(args.save_dir)
220
+ save_dir.mkdir(parents=True, exist_ok=True)
221
+ last_path = save_dir / args.save_name
222
+ best_path = save_dir / (Path(args.save_name).stem + ".best.pt")
223
+ best_weights_path = save_dir / (Path(args.save_name).stem + ".best_weights.pt")
224
+
225
+ best_val_srcc = -1e18
226
+ bad_epochs = 0
227
+
228
+ for epoch in range(1, int(args.epochs) + 1):
229
+ if (not did_unfreeze) and bool(args.finetune_last_stage) and epoch == (int(args.warmup_epochs) + 1):
230
+ model.unfreeze_clip_last_blocks(n_blocks=int(args.clip_unfreeze_blocks), also_unfreeze_ln=True)
231
+ did_unfreeze = True
232
+ print(f"[Finetune] Unfroze CLIP last {int(args.clip_unfreeze_blocks)} blocks")
233
+
234
+ tr_loss, tr_plcc, tr_srcc, tr_rmse = run_epoch(
235
+ model, loader_train, device,
236
+ optim=optim, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
237
+ desc=f"FT Train e{epoch:03d}", show_pbar=True
238
+ )
239
+ va_loss, va_plcc, va_srcc, va_rmse = run_epoch(
240
+ model, loader_val, device,
241
+ optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
242
+ desc=f"FT Val e{epoch:03d}", show_pbar=True
243
+ )
244
+ scheduler.step()
245
+
246
+ print(
247
+ f"epoch {epoch:03d} | "
248
+ f"train: loss={tr_loss:.4f} plcc={tr_plcc:.4f} srcc={tr_srcc:.4f} rmse={tr_rmse:.4f} | "
249
+ f"val: loss={va_loss:.4f} plcc={va_plcc:.4f} srcc={va_srcc:.4f} rmse={va_rmse:.4f}"
250
+ )
251
+
252
+ ckpt_out = {
253
+ "epoch": epoch,
254
+ "model": model.state_dict(),
255
+ "optim": optim.state_dict(),
256
+ "mos_mean": mos_mean,
257
+ "mos_std": mos_std,
258
+ "args": vars(args),
259
+ "best_val_srcc": best_val_srcc,
260
+ }
261
+ torch.save(ckpt_out, str(last_path))
262
+
263
+ if va_srcc > best_val_srcc:
264
+ best_val_srcc = va_srcc
265
+ bad_epochs = 0
266
+ ckpt_out["best_val_srcc"] = best_val_srcc
267
+ torch.save(ckpt_out, str(best_path))
268
+ torch.save(model.state_dict(), str(best_weights_path))
269
+ print(f" [best] val_srcc={best_val_srcc:.4f} -> saved {best_weights_path}")
270
+ else:
271
+ bad_epochs += 1
272
+ if bad_epochs >= int(args.patience):
273
+ print(f"[EarlyStop] val_srcc not improved for {bad_epochs} epochs. Stop.")
274
+ break
275
+
276
+ # load best and test
277
+ if best_weights_path.exists():
278
+ sd = torch.load(str(best_weights_path), map_location=device, weights_only=True)
279
+ model.load_state_dict(sd, strict=True)
280
+ print(f"Loaded best weights: {best_weights_path}")
281
+
282
+ te_loss, te_plcc, te_srcc, te_rmse = run_epoch(
283
+ model, loader_test, device,
284
+ optim=None, amp=amp, mos_mean=mos_mean, mos_std=mos_std,
285
+ desc="FT Test", show_pbar=True
286
+ )
287
+ print(f"FINETUNE TEST | loss={te_loss:.4f} plcc={te_plcc:.4f} srcc={te_srcc:.4f} rmse={te_rmse:.4f}")
288
+
289
+
290
+ if __name__ == "__main__":
291
+ main()
src/transfer_test_only.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["DECORD_DUPLICATE_WARNING_THRESHOLD"] = "1.0"
3
+ import argparse
4
+ import csv
5
+ from pathlib import Path
6
+ import torch
7
+ from torch.amp import autocast
8
+ from tqdm import tqdm
9
+
10
+ from train import VQADataset, com_loss, pearsonr, read_vid_mos_csv, spearmanr
11
+ from model.qd_model import QD_MODEL
12
+
13
+
14
+ def load_checkpoint(ckpt_path, device):
15
+ ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=True)
16
+ if isinstance(ckpt, dict) and "model" in ckpt:
17
+ return {
18
+ "state_dict": ckpt["model"],
19
+ "train_mos_mean": ckpt.get("mos_mean"),
20
+ "train_mos_std": ckpt.get("mos_std"),
21
+ "train_args": ckpt.get("args", {}),
22
+ "is_full_checkpoint": True,
23
+ }
24
+ if isinstance(ckpt, dict):
25
+ return {
26
+ "state_dict": ckpt,
27
+ "train_mos_mean": None,
28
+ "train_mos_std": None,
29
+ "train_args": {},
30
+ "is_full_checkpoint": False,
31
+ }
32
+ raise TypeError(f"Unsupported checkpoint type: {type(ckpt)!r}")
33
+
34
+ def infer_test_scale(rows):
35
+ mos_values = [float(mos) for _vid, mos in rows]
36
+ if not mos_values:
37
+ raise ValueError("Cannot infer test scale from empty rows")
38
+
39
+ lo = min(mos_values)
40
+ hi = max(mos_values)
41
+
42
+ if 0.0 <= lo and hi <= 1.0:
43
+ return 0.0, 1.0
44
+ if 1.0 <= lo and hi <= 5.0:
45
+ return 1.0, 5.0
46
+ if 0.0 <= lo and hi <= 5.0:
47
+ return 0.0, 5.0
48
+ return 0.0, 100.0
49
+
50
+ def linear_remap(x, src_min, src_max, dst_min, dst_max):
51
+ src_min = float(src_min)
52
+ src_max = float(src_max)
53
+ dst_min = float(dst_min)
54
+ dst_max = float(dst_max)
55
+
56
+ if abs(src_max - src_min) <= 1e-12:
57
+ raise ValueError("Source scale range must be non-zero")
58
+
59
+ return (x - src_min) / (src_max - src_min) * (dst_max - dst_min) + dst_min
60
+
61
+ def save_predictions_csv(save_path, vids, y_true_raw, pred_train_scale, pred_eval_scale):
62
+ save_path = Path(save_path)
63
+ save_path.parent.mkdir(parents=True, exist_ok=True)
64
+
65
+ with open(save_path, "w", newline="", encoding="utf-8") as f:
66
+ writer = csv.writer(f)
67
+ writer.writerow(["vid", "y_true_raw", "pred_train_scale", "pred_eval_scale"])
68
+ for vid, y_true, pred_train, pred_eval in zip(
69
+ vids,
70
+ y_true_raw.tolist(),
71
+ pred_train_scale.tolist(),
72
+ pred_eval_scale.tolist(),
73
+ strict=False,
74
+ ):
75
+ writer.writerow([vid, float(y_true), float(pred_train), float(pred_eval)])
76
+
77
+ return save_path
78
+
79
+ @torch.no_grad()
80
+ def evaluate_and_collect(
81
+ model,
82
+ loader,
83
+ device,
84
+ *,
85
+ amp=True,
86
+ train_mos_mean,
87
+ train_mos_std,
88
+ train_scale_min,
89
+ train_scale_max,
90
+ test_scale_min,
91
+ test_scale_max,
92
+ desc="",
93
+ show_pbar=True,
94
+ log_interval=10,
95
+ ):
96
+ model.eval()
97
+
98
+ losses = []
99
+ y_all = []
100
+ yhat_all = []
101
+ vids_all = []
102
+
103
+ it = loader
104
+ if show_pbar:
105
+ it = tqdm(loader, desc=desc, leave=False, dynamic_ncols=True)
106
+
107
+ for step, (rgb, w_art, w_str, y, vid) in enumerate(it, start=1):
108
+ rgb = rgb.to(device, non_blocking=True)
109
+ w_art = w_art.to(device, non_blocking=True)
110
+ w_str = w_str.to(device, non_blocking=True)
111
+ y = y.to(device, non_blocking=True).float()
112
+
113
+ device_type = "cuda" if str(device).startswith("cuda") else "cpu"
114
+ with autocast(device_type=device_type, enabled=(amp and device_type == "cuda")):
115
+ yhat, _aux = model(rgb, w_art, w_str)
116
+ loss, _loss_reg, _loss_rank = com_loss(yhat, y)
117
+
118
+ losses.append(loss.detach().float().cpu())
119
+ y_all.append(y.detach().float().cpu())
120
+ yhat_all.append(yhat.detach().float().cpu())
121
+ vids_all.extend(list(vid))
122
+
123
+ if show_pbar and (step % int(log_interval) == 0 or step == len(loader)):
124
+ avg_loss_so_far = torch.stack(losses).mean().item()
125
+ it.set_postfix({"loss": f"{avg_loss_so_far:.4f}"})
126
+
127
+ if y_all:
128
+ y_all = torch.cat(y_all, dim=0)
129
+ yhat_all = torch.cat(yhat_all, dim=0)
130
+ else:
131
+ y_all = torch.empty(0)
132
+ yhat_all = torch.empty(0)
133
+
134
+ y_true_raw = y_all * float(train_mos_std) + float(train_mos_mean)
135
+ pred_train_scale = yhat_all * float(train_mos_std) + float(train_mos_mean)
136
+ pred_eval_scale = linear_remap(
137
+ pred_train_scale,
138
+ src_min=float(train_scale_min),
139
+ src_max=float(train_scale_max),
140
+ dst_min=float(test_scale_min),
141
+ dst_max=float(test_scale_max),
142
+ )
143
+
144
+ plcc = pearsonr(y_true_raw, pred_eval_scale).item() if y_true_raw.numel() > 1 else 0.0
145
+ srcc = spearmanr(y_true_raw, pred_eval_scale).item() if y_true_raw.numel() > 1 else 0.0
146
+ rmse = (
147
+ torch.sqrt(torch.mean((pred_eval_scale - y_true_raw) ** 2)).item()
148
+ if y_true_raw.numel() > 0
149
+ else 0.0
150
+ )
151
+ avg_loss = torch.stack(losses).mean().item() if losses else 0.0
152
+
153
+ return {
154
+ "loss": avg_loss,
155
+ "plcc": plcc,
156
+ "srcc": srcc,
157
+ "rmse": rmse,
158
+ "vids": vids_all,
159
+ "y_true_raw": y_true_raw,
160
+ "pred_train_scale": pred_train_scale,
161
+ "pred_eval_scale": pred_eval_scale,
162
+ }
163
+
164
+ def main():
165
+ ap = argparse.ArgumentParser()
166
+ ap.add_argument("--ckpt_path", type=str, default="/home/xinyi/Project/FD-VQA/src/checkpoints/lsvq/qd_model.best.pt")
167
+ ap.add_argument("--csv_path", type=str, default="/home/xinyi/Project/FD-VQA/metadata/KVQ_metadata.csv")
168
+ ap.add_argument("--db_path", type=str, default="/media/xinyi/server/video_dataset/KVQ")
169
+
170
+ ap.add_argument("--clip_len", type=int, default=16)
171
+ ap.add_argument("--resize", type=int, default=224)
172
+ ap.add_argument("--win", type=int, default=6)
173
+ ap.add_argument("--win_step", type=int, default=1)
174
+
175
+ ap.add_argument("--batch_size", type=int, default=8)
176
+ ap.add_argument("--num_workers", type=int, default=2)
177
+ ap.add_argument("--device", type=str, default="cuda")
178
+ ap.add_argument("--no_amp", action="store_true")
179
+
180
+ ap.add_argument("--train_scale_min", type=float, default=0.0)
181
+ ap.add_argument("--train_scale_max", type=float, default=100.0)
182
+ ap.add_argument("--test_scale_min", type=float, default=1.0)
183
+ ap.add_argument("--test_scale_max", type=float, default=5.0)
184
+
185
+ ap.add_argument("--save_pred_csv", type=str, default="/home/xinyi/Project/FD-VQA/src/transfer_test/transfer_test_only_konvid_1k.csv")
186
+ args = ap.parse_args()
187
+
188
+ device = torch.device(args.device)
189
+ amp = not bool(args.no_amp)
190
+ ckpt_info = load_checkpoint(Path(args.ckpt_path), device)
191
+
192
+ train_mos_mean = ckpt_info["train_mos_mean"]
193
+ train_mos_std = ckpt_info["train_mos_std"]
194
+ if train_mos_mean is None or train_mos_std is None:
195
+ raise ValueError(
196
+ "Prefer loading *.best.pt / *.pt, or pass --train_mos_mean and --train_mos_std manually."
197
+ )
198
+ if float(train_mos_std) <= 1e-8:
199
+ raise ValueError("train_mos_std must be > 0")
200
+
201
+ rows = read_vid_mos_csv(args.csv_path)
202
+ if not rows:
203
+ raise ValueError(f"No rows found in csv: {args.csv_path}")
204
+
205
+ if args.test_scale_min is None or args.test_scale_max is None:
206
+ inferred_test_scale_min, inferred_test_scale_max = infer_test_scale(rows)
207
+ test_scale_min = inferred_test_scale_min
208
+ test_scale_max = inferred_test_scale_max
209
+ else:
210
+ test_scale_min = float(args.test_scale_min)
211
+ test_scale_max = float(args.test_scale_max)
212
+
213
+ dataset = VQADataset(
214
+ rows,
215
+ args.db_path,
216
+ clip_len=args.clip_len,
217
+ size=args.resize,
218
+ win=args.win,
219
+ win_step=args.win_step,
220
+ mos_mean=float(train_mos_mean),
221
+ mos_std=float(train_mos_std),
222
+ )
223
+
224
+ pin = str(device).startswith("cuda")
225
+ loader = torch.utils.data.DataLoader(
226
+ dataset,
227
+ batch_size=int(args.batch_size),
228
+ shuffle=False,
229
+ num_workers=int(args.num_workers),
230
+ pin_memory=pin,
231
+ drop_last=False,
232
+ prefetch_factor=4 if int(args.num_workers) > 0 else None,
233
+ )
234
+
235
+ model = QD_MODEL(
236
+ clip_model="openai/clip-vit-base-patch16",
237
+ ).to(device)
238
+ model.load_state_dict(ckpt_info["state_dict"], strict=True)
239
+
240
+ print(f"Loaded checkpoint: {args.ckpt_path}")
241
+ print(f"Training normalization: mean={float(train_mos_mean):.6f}, std={float(train_mos_std):.6f}")
242
+ print(
243
+ f"Scale mapping: train=[{float(args.train_scale_min):.3f}, {float(args.train_scale_max):.3f}] -> "
244
+ f"test=[{float(test_scale_min):.3f}, {float(test_scale_max):.3f}]"
245
+ )
246
+ print(f"Test rows: {len(rows)}")
247
+
248
+ metrics = evaluate_and_collect(
249
+ model,
250
+ loader,
251
+ device,
252
+ amp=amp,
253
+ train_mos_mean=float(train_mos_mean),
254
+ train_mos_std=float(train_mos_std),
255
+ train_scale_min=float(args.train_scale_min),
256
+ train_scale_max=float(args.train_scale_max),
257
+ test_scale_min=float(test_scale_min),
258
+ test_scale_max=float(test_scale_max),
259
+ desc="Cross-dataset test",
260
+ show_pbar=True,
261
+ log_interval=10,
262
+ )
263
+
264
+ print(
265
+ "TEST | "
266
+ f"loss={metrics['loss']:.4f} "
267
+ f"plcc={metrics['plcc']:.4f} "
268
+ f"srcc={metrics['srcc']:.4f} "
269
+ f"rmse={metrics['rmse']:.4f}"
270
+ )
271
+
272
+ if args.save_pred_csv:
273
+ save_path = save_predictions_csv(
274
+ args.save_pred_csv,
275
+ metrics["vids"],
276
+ metrics["y_true_raw"],
277
+ metrics["pred_train_scale"],
278
+ metrics["pred_eval_scale"],
279
+ )
280
+ print(f"Saved predictions to: {save_path}")
281
+
282
+
283
+ if __name__ == "__main__":
284
+ main()