Thefalley committed on
Commit
a3596f8
·
verified ·
1 Parent(s): 0de91f7

Add inference script with decoder

Browse files
Files changed (1) hide show
  1. inference.py +173 -0
inference.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 05_inference_test.py — Inference for YOLOv4-tiny raw ONNX (float or INT8).
3
+
4
+ Reuses the decoder pattern from clean_yolov4/05b_inference_int8_raw.py
5
+ but with YOLOv4-tiny's 2-head config.
6
+ """
7
+ import os, sys, time, argparse
8
+ import numpy as np, cv2, requests
9
+ from io import BytesIO
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import onnxruntime as ort
12
+
13
# Directory containing this script; default model and output paths are
# resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# YOLOv4-tiny constants
# Anchor priors as (width, height); the decoder indexes into this list.
ANCHORS = [(10, 14), (23, 27), (37, 58), (81, 82), (135, 169), (344, 319)]
# (stride, anchor_indices, scale_xy) — order matches DarknetRaw export
CFG_HEADS = [
    (32, [3, 4, 5], 1.05),  # out0 = stride 32 (13x13)
    (16, [1, 2, 3], 1.05),  # out1 = stride 16 (26x26)
]

INPUT_SIZE = 416   # side length of the square network input
SCORE_THR = 0.30   # minimum objectness * class score to keep a detection
NMS_THR = 0.45     # IoU at or above which a same-class box is suppressed

COCO = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# One RGB colour per class. A private RandomState seeded with 42 produces the
# same sequence as seeding the global NumPy RNG would, but does not reseed
# process-wide random state as a side effect of importing this module.
PALETTE = [
    (int(r), int(g), int(b))
    for r, g, b in np.random.RandomState(42).randint(60, 255, (80, 3))
]
40
+
41
+
42
def get_font(size):
    """Return a font of the requested pixel size for drawing labels.

    Tries a list of common TrueType font file names in order and returns the
    first one Pillow can load. The original Windows fonts come first so
    behaviour there is unchanged; fonts usually present on Linux follow, so
    the requested size is honoured on those systems too.

    If no TrueType font can be loaded, Pillow's built-in default font is
    returned, sized when the installed Pillow supports it.
    """
    candidates = (
        "arialbd.ttf", "arial.ttf", "segoeui.ttf",            # Windows
        "DejaVuSans-Bold.ttf", "DejaVuSans.ttf",              # common on Linux
        "LiberationSans-Bold.ttf", "LiberationSans-Regular.ttf",
    )
    for font_name in candidates:
        try:
            return ImageFont.truetype(font_name, size)
        except Exception:
            # Best effort: a missing or unreadable font just means "try the next".
            continue
    try:
        # Newer Pillow releases accept a size for the built-in font.
        return ImageFont.load_default(size)
    except Exception:
        # Older releases take no argument; fall back to the fixed-size default.
        return ImageFont.load_default()
47
+
48
+
49
def sigmoid(x):
    """Element-wise logistic function 1 / (1 + e^-x).

    The input is clipped to [-50, 50] before exponentiation so that np.exp
    cannot overflow for large-magnitude logits.
    """
    bounded = np.clip(x, -50, 50)
    denominator = 1.0 + np.exp(-bounded)
    return 1.0 / denominator
51
+
52
+
53
def letterbox_nchw(rgb, size=INPUT_SIZE):
    """Fit an image into a square canvas and return it as an NCHW float blob.

    The image is scaled uniformly so its longer side equals ``size`` and is
    pasted into the top-left corner of a ``size`` x ``size`` canvas filled
    with the value 114. Because the padding is only on the right and bottom,
    box coordinates map back to the source image by dividing by the scale.

    Args:
        rgb: array of shape (H, W, 3); presumably uint8 RGB — the caller in
            this file passes ``np.array`` of an RGB PIL image.
        size: side length of the square output canvas.

    Returns:
        (blob, scale): ``blob`` is float32 with shape (1, 3, size, size) and
        values divided by 255; ``scale`` is the resize factor that was applied.
    """
    src_h, src_w = rgb.shape[:2]
    scale = min(size / src_h, size / src_w)
    # Clamp to one pixel: for extreme aspect ratios the short side rounds to
    # 0, and cv2.resize rejects an empty target size.
    new_h = max(1, int(round(src_h * scale)))
    new_w = max(1, int(round(src_w * scale)))
    resized = cv2.resize(rgb, (new_w, new_h))  # cv2 takes (width, height)
    canvas = np.full((size, size, 3), 114, np.uint8)
    canvas[:new_h, :new_w] = resized
    chw = canvas.astype(np.float32).transpose(2, 0, 1) / 255.0
    return np.expand_dims(chw, 0), scale
60
+
61
+
62
def decode_one_head(raw, stride, anchor_idxs, scale_xy):
    """Decode one raw YOLO head into flat prediction rows.

    Args:
        raw: array of shape (1, A * (5 + C), H, W), where A is
            ``len(anchor_idxs)``; only batch size 1 is supported by the
            reshape below.
        stride: down-sampling factor of this head relative to the input.
        anchor_idxs: indices into ANCHORS used by this head.
        scale_xy: scale applied to the sigmoid of the centre offsets.

    Returns:
        Array of shape (H * W * A, 5 + C) with columns
        [cx, cy, w, h, objectness, class scores...]; box values are multiplied
        by ``stride`` so they are in input-image units.
    """
    _, channels, grid_h, grid_w = raw.shape
    num_anchors = len(anchor_idxs)
    row_len = channels // num_anchors  # 5 box/objectness values + class scores

    # (1, A*(5+C), H, W) -> (H, W, A, 5+C)
    feats = raw.reshape(1, num_anchors, row_len, grid_h, grid_w)[0].transpose(2, 3, 0, 1)

    centre = sigmoid(feats[..., 0:2]) * scale_xy - (scale_xy - 1) / 2
    extent = np.exp(np.clip(feats[..., 2:4], -10, 10))  # clip guards exp overflow
    objectness = sigmoid(feats[..., 4:5])
    class_scores = sigmoid(feats[..., 5:])

    # Anchor priors expressed in grid cells for this head.
    priors = np.array(
        [[aw / stride, ah / stride] for aw, ah in (ANCHORS[i] for i in anchor_idxs)],
        dtype=np.float32,
    )
    extent = extent * priors[None, None, :, :]

    # Per-cell (x, y) offsets, broadcast across the anchor axis.
    row_idx, col_idx = np.indices((grid_h, grid_w))
    cell_xy = np.stack([col_idx, row_idx], axis=-1).astype(np.float32)
    centre = centre + cell_xy[:, :, None, :]

    # Grid cells -> input-image units.
    centre *= stride
    extent *= stride

    stacked = np.concatenate([centre, extent, objectness, class_scores], axis=-1)
    return stacked.reshape(-1, row_len)
79
+
80
+
81
def decode_all(raws, ratio):
    """Decode every head, threshold by score and run per-class greedy NMS.

    Args:
        raws: raw head outputs, paired positionally with CFG_HEADS.
        ratio: letterbox scale; box coordinates are divided by it to map them
            back to the source image.

    Returns:
        List of dicts with keys "class" (int), "score" (float) and "bbox"
        ([x1, y1, x2, y2] floats), ordered by descending score. Empty list
        when nothing exceeds SCORE_THR.
    """

    def overlap(box_a, box_b):
        # Intersection-over-union of two [x1, y1, x2, y2] boxes.
        ax1, ay1, ax2, ay2 = box_a
        bx1, by1, bx2, by2 = box_b
        iw = max(0, min(ax2, bx2) - max(ax1, bx1))
        ih = max(0, min(ay2, by2) - max(ay1, by1))
        inter = iw * ih
        area_a = max(0, (ax2 - ax1) * (ay2 - ay1))
        area_b = max(0, (bx2 - bx1) * (by2 - by1))
        return inter / (area_a + area_b - inter + 1e-9)  # epsilon avoids 0/0

    per_head = [
        decode_one_head(head_raw, stride, idxs, sxy)
        for head_raw, (stride, idxs, sxy) in zip(raws, CFG_HEADS)
    ]
    rows = np.concatenate(per_head, axis=0)

    class_scores = rows[:, 5:]
    labels = np.argmax(class_scores, axis=1)
    best_class_score = class_scores[np.arange(len(class_scores)), labels]
    scores = rows[:, 4] * best_class_score

    mask = scores > SCORE_THR
    if not mask.any():
        return []
    rows, scores, labels = rows[mask], scores[mask], labels[mask]

    # Centre/size -> corner coordinates in source-image pixels.
    cx, cy, bw, bh = rows[:, 0], rows[:, 1], rows[:, 2], rows[:, 3]
    left = (cx - bw / 2) / ratio
    top = (cy - bh / 2) / ratio
    right = (cx + bw / 2) / ratio
    bottom = (cy + bh / 2) / ratio

    candidates = [
        {"class": int(label), "score": float(score),
         "bbox": [float(l), float(t), float(r), float(b)]}
        for label, score, l, t, r, b in zip(labels, scores, left, top, right, bottom)
    ]
    candidates.sort(key=lambda det: -det["score"])

    # Greedy NMS: keep the best remaining box, drop same-class boxes that
    # overlap it by NMS_THR or more, repeat on what is left.
    survivors = []
    while candidates:
        best = candidates[0]
        survivors.append(best)
        candidates = [
            det for det in candidates[1:]
            if det["class"] != best["class"]
            or overlap(best["bbox"], det["bbox"]) < NMS_THR
        ]
    return survivors
110
+
111
+
112
def draw(pil, dets):
    """Return a copy of ``pil`` with detection boxes and labels drawn on it.

    Args:
        pil: source PIL image; it is not modified.
        dets: iterable of dicts with "bbox" ([x1, y1, x2, y2]), "class"
            (int index) and "score" (0..1), as produced by ``decode_all``.

    Returns:
        A new PIL image with one outlined box and one "<name> <pct>%" label
        per detection.
    """
    img = pil.copy()
    canvas = ImageDraw.Draw(img)
    W, H = img.size
    thickness = max(3, min(W, H) // 200)
    font = get_font(max(14, min(W, H) // 40))
    for det in dets:
        x1, y1, x2, y2 = det["bbox"]
        # Clamp the box to the image so outlines and labels stay visible.
        x1 = max(0, min(x1, W - 1)); y1 = max(0, min(y1, H - 1))
        x2 = max(0, min(x2, W - 1)); y2 = max(0, min(y2, H - 1))
        cls = det["class"]
        # The palette is indexed modulo its length; give the name the same
        # tolerance instead of raising IndexError on an unexpected class id.
        cname = COCO[cls] if 0 <= cls < len(COCO) else f"class {cls}"
        color = PALETTE[cls % len(PALETTE)]
        # Grow the outline outwards one pixel at a time.
        for t in range(thickness):
            canvas.rectangle([x1 - t, y1 - t, x2 + t, y2 + t], outline=color)
        label = f"{cname} {det['score']*100:.0f}%"
        # Measure the text relative to a (0, 0) anchor so placement follows
        # the real font size rather than a fixed pixel offset.
        _, _, text_right, text_bottom = canvas.textbbox((0, 0), label, font=font)
        # Sit the label on top of the box when there is room above it,
        # otherwise tuck it just inside the box so it is not cut off.
        anchor_y = y1 - text_bottom if y1 - text_bottom >= 0 else y1
        # Shift left when the label would run past the right image edge.
        anchor_x = max(0, min(x1, W - text_right))
        background = canvas.textbbox((anchor_x, anchor_y), label, font=font)
        canvas.rectangle(background, fill=color)
        # Draw at the same anchor that was measured so text and background align.
        canvas.text((anchor_x, anchor_y), label, fill=(0, 0, 0), font=font)
    return img
127
+
128
+
129
def main():
    """Run the ONNX model on a fixed set of test images and save annotated copies.

    Command-line options:
        --onnx     path to the raw YOLOv4-tiny ONNX model.
        --out-dir  directory that receives one annotated PNG per test image.

    Images are downloaded over HTTP; any image that cannot be fetched or
    decoded is skipped with a message. For each processed image the number of
    detections, the inference latency and up to eight top detections are
    printed.

    Returns:
        1 if the model file does not exist, otherwise 0.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx", default=os.path.join(SCRIPT_DIR, "out_onnx", "yolov4-tiny-416_float_raw.onnx"))
    ap.add_argument("--out-dir", default=os.path.join(SCRIPT_DIR, "inference"))
    args = ap.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    if not os.path.isfile(args.onnx):
        print(f"[FAIL] no existe: {args.onnx}")
        return 1
    print(f"Loading: {args.onnx} ({os.path.getsize(args.onnx)/1e6:.1f} MB)")
    sess = ort.InferenceSession(args.onnx, providers=["CPUExecutionProvider"])
    inp = sess.get_inputs()[0].name
    print(f" inputs: {[i.name for i in sess.get_inputs()]}")
    print(f" outputs: {[(o.name, o.shape) for o in sess.get_outputs()]}")
    tests = [
        ("dog", "https://raw.githubusercontent.com/pjreddie/darknet/master/data/dog.jpg"),
        ("traffic", "http://images.cocodataset.org/val2017/000000011197.jpg"),
        ("skaters", "http://images.cocodataset.org/val2017/000000087038.jpg"),
        ("kitchen", "http://images.cocodataset.org/val2017/000000037777.jpg"),
        ("market", "http://images.cocodataset.org/val2017/000000289343.jpg"),
        ("parking", "http://images.cocodataset.org/val2017/000000017627.jpg"),
        ("dining", "http://images.cocodataset.org/val2017/000000080340.jpg"),
        ("bus", "http://images.cocodataset.org/val2017/000000000785.jpg"),
    ]
    for name, url in tests:
        try:
            r = requests.get(url, timeout=30)
            r.raise_for_status()
            pil = Image.open(BytesIO(r.content)).convert("RGB")
        except Exception as e:
            # Deliberate best effort: an unreachable or undecodable image
            # should not abort the remaining tests.
            print(f"[skip {name}] {e}")
            continue
        rgb = np.array(pil)
        blob, ratio = letterbox_nchw(rgb, INPUT_SIZE)
        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock and can be too coarse (or jump) for millisecond latencies.
        t0 = time.perf_counter()
        outs = sess.run(None, {inp: blob})
        t = (time.perf_counter() - t0) * 1000
        dets = decode_all(outs, ratio)
        annotated = draw(pil, dets)
        out_path = os.path.join(args.out_dir, f"{name}.png")
        annotated.save(out_path)
        print(f" {name:>10s}: {len(dets):>2d} dets in {t:6.1f} ms")
        for d in dets[:8]:
            print(f" {COCO[d['class']]:>16s} {d['score']*100:5.1f}%")
    return 0


if __name__ == "__main__":
    sys.exit(main() or 0)