fangmingguo committed on
Commit
ddf1c32
·
verified ·
1 Parent(s): 81f398f

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.axmodel filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.axmodel filter=lfs diff=lfs merge=lfs -text
37
+ result_aquarium_yolov8.jpg filter=lfs diff=lfs merge=lfs -text
38
+ test.png filter=lfs diff=lfs merge=lfs -text
AX637/aquarium_yolov8s.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55a057585f4a9a8f3136a1092c90a44117cdc1b744901c1e9b5b50892abf0d90
3
+ size 11363800
AX650/aquarium_animials.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30f83af29976d26bdba670ea91e7d5769758e4be1e56b2f232493d568f4443ba
3
+ size 11832954
aquarium_animals_20260404_002650_job_115_best_0.48.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2189acce8e3bf4673fc8a2fa4a2b942e53079355a7e083dc7133ae606fdd3530
3
+ size 44752247
aquarium_animials_cut.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b84f26d0a592e2dcf4109ee4ddb84b185b37ed1b0c9b8dac6c92f9ec627e00df
3
+ size 44552040
aquarium_calib.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d56e7f04a6093562b3587459d8c9842a57df2465d50d5a5cdcebbd3c15ddf3c7
3
+ size 23244800
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "ONNX",
3
+ "npu_mode": "NPU3",
4
+ "quant": {
5
+ "input_configs": [
6
+ {
7
+ "tensor_name": "images",
8
+ "calibration_dataset": "./aquarium_calib.tar",
9
+ "calibration_size": 256,
10
+ "calibration_mean": [0, 0, 0],
11
+ "calibration_std": [255.0, 255.0, 255.0]
12
+ }
13
+ ],
14
+ "calibration_method": "MinMax",
15
+ "precision_analysis": true,
16
+ "precision_analysis_method": "EndToEnd"
17
+ },
18
+ "input_processors": [
19
+ {
20
+ "tensor_name": "images",
21
+ "tensor_format": "BGR",
22
+ "src_format": "BGR",
23
+ "src_dtype": "U8",
24
+ "src_layout": "NHWC"
25
+ }
26
+ ],
27
+ "output_processors": [
28
+ { "tensor_name": "stride_8_cls" },
29
+ { "tensor_name": "stride_8_bbox" },
30
+ { "tensor_name": "stride_16_cls" },
31
+ { "tensor_name": "stride_16_bbox" },
32
+ { "tensor_name": "stride_32_cls" },
33
+ { "tensor_name": "stride_32_bbox" }
34
+ ],
35
+ "compiler": {
36
+ "check": 0
37
+ }
38
+ }
infer_yolov8_pyax.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import axengine as ort
12
+
13
# Root logging is configured at DEBUG so the per-stage timing messages emitted
# by this script are visible; timestamps carry millisecond resolution.
logging.basicConfig(
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    level=logging.DEBUG,
)
# Single named logger shared by the whole script.
logger = logging.getLogger("Aquarium-YOLOv8-6way")
19
+
20
# Default value of the --score-thres command line option.
PROB_THRESHOLD = 0.45
# Default value of the --nms-thres command line option.
NMS_THRESHOLD = 0.45
# Number of distribution bins per box side in the bbox head outputs
# (each bbox tensor is expected to carry 4 * REG_MAX channels).
REG_MAX = 16
# Strides of the three detection scales, in the order they are decoded.
STRIDES = (8, 16, 32)
# Class names used when --names is not given; index == class id.
DEFAULT_NAMES = ["fish", "turtle", "shrimp", "crab", "snail"]
# One drawing colour per default class; indexed modulo its length when drawing.
DEFAULT_COLORS = [
    (56, 56, 255), (151, 157, 255), (31, 112, 255),
    (29, 178, 255), (49, 210, 207),
]
32
+
33
+
34
def infer_hw_layout(shape):
    """Guess the spatial size and memory layout of a model's image input.

    Args:
        shape: Iterable with the input dimensions as reported by the runtime.
            Entries may be integers, ``None``, or symbolic names for dynamic
            axes.

    Returns:
        A tuple ``(height, width, layout)`` where ``layout`` is ``"NHWC"``
        when the last of four dimensions is 3, ``"NCHW"`` when the second of
        four dimensions is 3. Any other shape yields ``(640, 640, "NCHW")``.
        A height/width entry that is missing, non-numeric or not positive is
        replaced by 640.
    """
    fallback = 640

    def _size(dim):
        # Dynamic axes may show up as None, 0, -1 or a symbolic string; none
        # of those is a usable pixel count, so fall back to the default size.
        try:
            value = int(dim)
        except (TypeError, ValueError, OverflowError):
            return fallback
        return value if value > 0 else fallback

    dims = list(shape)
    if len(dims) == 4:
        # The channels-last test runs first, so a 4-D shape with 3 in both
        # positions is reported as NHWC.
        if dims[-1] == 3:
            return _size(dims[1]), _size(dims[2]), "NHWC"
        if dims[1] == 3:
            return _size(dims[2]), _size(dims[3]), "NCHW"
    return fallback, fallback, "NCHW"
45
+
46
+
47
def letterbox(bgr, dst_h, dst_w, pad_value=114):
    """Resize an image to fit ``dst_h`` x ``dst_w`` keeping its aspect ratio.

    The image is scaled by the largest factor that fits inside the target and
    the remaining area is filled with a constant border, split as evenly as
    possible between the two sides (the extra pixel goes bottom/right).

    Args:
        bgr: Image array whose first two dimensions are height and width.
        dst_h: Target height in pixels.
        dst_w: Target width in pixels.
        pad_value: Border intensity, applied to each of three channels.

    Returns:
        ``(out, meta)`` where ``out`` is the padded image and ``meta`` is a
        dict with ``src_h``, ``src_w``, ``dst_h``, ``dst_w``, ``scale``,
        ``pad_top`` and ``pad_left`` describing the transform.

    Raises:
        ValueError: If the source image has zero height or width.
    """
    h, w = bgr.shape[:2]
    if h <= 0 or w <= 0:
        raise ValueError(f"cannot letterbox an empty image of shape {bgr.shape!r}")

    scale = min(dst_h / h, dst_w / w)
    # For extreme aspect ratios the short side can round down to 0, which
    # cv2.resize rejects; keep at least one pixel.
    new_h = max(1, int(round(h * scale)))
    new_w = max(1, int(round(w * scale)))
    resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    top = (dst_h - new_h) // 2
    bot = dst_h - new_h - top
    left = (dst_w - new_w) // 2
    right = dst_w - new_w - left
    out = cv2.copyMakeBorder(
        resized, top, bot, left, right, cv2.BORDER_CONSTANT,
        value=(pad_value, pad_value, pad_value),
    )
    meta = {
        "src_h": h, "src_w": w,
        "dst_h": dst_h, "dst_w": dst_w,
        "scale": scale,
        "pad_top": top, "pad_left": left,
    }
    return out, meta
67
+
68
+
69
+ def _to_hwc(t, c_expected):
70
+ a = np.asarray(t)
71
+ if a.ndim == 3:
72
+ a = a[None, ...]
73
+ if a.shape[-1] == c_expected:
74
+ return a[0]
75
+ if a.shape[1] == c_expected:
76
+ return np.transpose(a[0], (1, 2, 0))
77
+ raise ValueError(f"unexpected shape {a.shape!r} for C={c_expected}")
78
+
79
+
80
def group_outputs(out_names, outs, cls_num):
    """Pair the class and box tensors of every detection stride.

    When all outputs are named ``stride_<s>_cls`` / ``stride_<s>_bbox`` for
    each stride in ``STRIDES`` they are matched by name. Otherwise tensors
    are classified by channel count (``cls_num`` for class heads,
    ``4 * REG_MAX`` for box heads; the class test runs first) and assigned
    to strides from the largest tensor to the smallest.

    Args:
        out_names: Output tensor names, in the same order as ``outs``.
        outs: Output tensors returned by the inference session.
        cls_num: Number of classes predicted by the model.

    Returns:
        Dict mapping each stride to a ``(cls_hwc, bbox_hwc)`` tuple of
        channels-last arrays.

    Raises:
        ValueError: If the fallback path does not find exactly three class
            tensors and three box tensors, or a tensor has an unexpected
            shape.
    """
    name_to_arr = dict(zip(out_names, outs))
    by_stride = {}

    wanted = [(s, f"stride_{s}_cls", f"stride_{s}_bbox") for s in STRIDES]
    if all(c in name_to_arr and b in name_to_arr for _, c, b in wanted):
        for s, cls_name, bbox_name in wanted:
            by_stride[s] = (
                _to_hwc(name_to_arr[cls_name], cls_num),
                _to_hwc(name_to_arr[bbox_name], 4 * REG_MAX),
            )
        return by_stride

    # Fallback: names are not the expected ones, classify by channel count.
    cls_outs, bb_outs = [], []
    for t in outs:
        a = np.asarray(t)
        if a.ndim == 3:
            a = a[None, ...]
        c_last, c_first = a.shape[-1], a.shape[1]
        if cls_num in (c_last, c_first):
            cls_outs.append(a)
        elif (4 * REG_MAX) in (c_last, c_first):
            bb_outs.append(a)

    # Within a group the channel count is fixed, so the element count orders
    # tensors by feature-map area whatever their layout (shape[1] * shape[2]
    # would be C * H for channels-first tensors). Largest map = finest stride.
    cls_outs.sort(key=lambda x: x.size, reverse=True)
    bb_outs.sort(key=lambda x: x.size, reverse=True)
    if len(cls_outs) != 3 or len(bb_outs) != 3:
        raise ValueError(
            f"expected 3 cls + 3 bbox, got {len(cls_outs)} cls + {len(bb_outs)} bbox"
        )
    for s, ct, bt in zip(STRIDES, cls_outs, bb_outs):
        by_stride[s] = (_to_hwc(ct, cls_num), _to_hwc(bt, 4 * REG_MAX))
    return by_stride
111
+
112
+
113
def decode_one_scale(stride, cls_hwc, bbox_hwc, prob_thr, dst_h, dst_w):
    """Decode the detections of one feature-map scale.

    For every cell the best class logit is compared against the logit
    equivalent of ``prob_thr``. For the surviving cells the four
    ``REG_MAX``-bin box-side distributions are soft-maxed, reduced to their
    expected value, scaled by ``stride`` and applied around the cell centre.

    Args:
        stride: Pixel stride of this feature map.
        cls_hwc: ``(H, W, num_classes)`` array of class logits.
        bbox_hwc: ``(H, W, 4 * REG_MAX)`` array of box-side distributions.
        prob_thr: Minimum class probability. Values ``<= 0`` keep every
            cell, values ``>= 1`` keep none.
        dst_h: Network input height used to clip boxes.
        dst_w: Network input width used to clip boxes.

    Returns:
        ``(boxes, probs, labels)``: float32 ``(N, 4)`` boxes as
        ``x0, y0, x1, y1`` in network-input pixels, float32 ``(N,)``
        probabilities and int32 ``(N,)`` class ids.

    Raises:
        ValueError: If ``bbox_hwc`` does not match the grid of ``cls_hwc`` or
            does not have ``4 * REG_MAX`` channels.
    """
    hf, wf, _ = cls_hwc.shape
    # Explicit check instead of assert so it still runs under ``python -O``.
    if bbox_hwc.shape[:2] != (hf, wf) or bbox_hwc.shape[2] != 4 * REG_MAX:
        raise ValueError(
            f"bbox shape {bbox_hwc.shape!r} does not match cls grid "
            f"({hf}, {wf}) with {4 * REG_MAX} channels"
        )

    # Threshold in logit space so the sigmoid is only evaluated on survivors.
    if 0 < prob_thr < 1:
        logit_thr = -np.log(1.0 / prob_thr - 1.0)
    elif prob_thr >= 1:
        # No finite logit reaches probability 1, so nothing may pass.
        logit_thr = np.inf
    else:
        logit_thr = -np.inf

    cls_max = cls_hwc.max(axis=2)
    cls_arg = cls_hwc.argmax(axis=2)
    keep = cls_max >= logit_thr
    if not keep.any():
        return (np.empty((0, 4), np.float32),
                np.empty((0,), np.float32),
                np.empty((0,), np.int32))

    yi, xi = np.where(keep)
    logits = cls_max[yi, xi].astype(np.float64)
    probs = (1.0 / (1.0 + np.exp(-logits))).astype(np.float32)
    labels = cls_arg[yi, xi].astype(np.int32)

    # Softmax over the bins of each side (max subtracted for stability),
    # then the expected bin index gives the distance in stride units.
    dfl = bbox_hwc[yi, xi].reshape(-1, 4, REG_MAX).astype(np.float64)
    dfl = dfl - dfl.max(axis=-1, keepdims=True)
    e = np.exp(dfl)
    sm = e / e.sum(axis=-1, keepdims=True)
    proj = np.arange(REG_MAX, dtype=np.float64)
    ltrb = (sm * proj).sum(axis=-1) * stride

    # Cell centres in network-input pixels.
    cx = (xi + 0.5) * stride
    cy = (yi + 0.5) * stride
    x0 = cx - ltrb[:, 0]
    y0 = cy - ltrb[:, 1]
    x1 = cx + ltrb[:, 2]
    y1 = cy + ltrb[:, 3]
    boxes = np.stack([x0, y0, x1, y1], axis=1).astype(np.float32)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, dst_w - 1)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, dst_h - 1)
    return boxes, probs, labels
148
+
149
+
150
def per_class_nms(boxes_xyxy, scores, labels, score_thr, iou_thr):
    """Run OpenCV non-maximum suppression separately for every class id.

    Args:
        boxes_xyxy: ``(N, 4)`` array of ``x0, y0, x1, y1`` boxes.
        scores: ``(N,)`` array of detection scores.
        labels: ``(N,)`` array of class ids.
        score_thr: Score threshold forwarded to ``cv2.dnn.NMSBoxes``.
        iou_thr: Overlap threshold forwarded to ``cv2.dnn.NMSBoxes``.

    Returns:
        int64 array with the indices (into the inputs) of the boxes that
        survive, grouped by ascending class id.
    """
    if len(boxes_xyxy) == 0:
        return np.empty((0,), np.int64)

    survivors = []
    for cls_id in np.unique(labels):
        members = np.where(labels == cls_id)[0]
        left = boxes_xyxy[members, 0]
        top = boxes_xyxy[members, 1]
        # NMSBoxes takes x, y, width, height rectangles.
        rects_xywh = np.column_stack([
            left,
            top,
            boxes_xyxy[members, 2] - left,
            boxes_xyxy[members, 3] - top,
        ]).tolist()
        picked = cv2.dnn.NMSBoxes(
            rects_xywh, scores[members].tolist(), score_thr, iou_thr
        )
        # The result is flattened when it is an ndarray; any other return
        # type is iterated as-is.
        if isinstance(picked, np.ndarray):
            picked = picked.flatten().tolist()
        for local in picked:
            survivors.append(int(members[local]))
    return np.array(survivors, dtype=np.int64)
167
+
168
+
169
def unletterbox(boxes_xyxy, meta):
    """Map boxes from letterboxed coordinates back to the source image.

    The padding offset is removed, the result is divided by the resize scale
    and the boxes are clipped to the source image bounds.

    Args:
        boxes_xyxy: ``(N, 4)`` array of ``x0, y0, x1, y1`` boxes. Integer
            arrays are accepted and converted to floating point.
        meta: Dict with ``pad_left``, ``pad_top``, ``scale``, ``src_w`` and
            ``src_h``.

    Returns:
        A new floating-point array of boxes in source-image pixels. An empty
        input is returned unchanged.
    """
    if len(boxes_xyxy) == 0:
        return boxes_xyxy
    src = np.asarray(boxes_xyxy)
    # Work on a floating copy: in-place division fails on integer arrays and
    # the caller's array must stay untouched. float32 input stays float32.
    out = src.astype(np.result_type(src.dtype, np.float32))
    out[:, [0, 2]] -= meta["pad_left"]
    out[:, [1, 3]] -= meta["pad_top"]
    out /= meta["scale"]
    out[:, [0, 2]] = np.clip(out[:, [0, 2]], 0, meta["src_w"] - 1)
    out[:, [1, 3]] = np.clip(out[:, [1, 3]], 0, meta["src_h"] - 1)
    return out
179
+
180
+
181
def draw(img, boxes_xyxy, scores, labels, names, colors):
    """Render detections on a copy of ``img``.

    Each detection gets a 2-pixel rectangle and a filled caption showing the
    class name and the score with two decimals.

    Args:
        img: Image array; it is copied, not modified.
        boxes_xyxy: Iterable of ``x0, y0, x1, y1`` boxes.
        scores: Iterable of scores, parallel to the boxes.
        labels: Iterable of class ids, parallel to the boxes.
        names: Class names indexed by class id; ids outside the list are
            shown as numbers.
        colors: Colour tuples, indexed by class id modulo its length.

    Returns:
        The annotated copy of the image.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    canvas = img.copy()
    for box, score, label in zip(boxes_xyxy, scores, labels):
        cls_id = int(label)
        x0, y0, x1, y1 = (int(round(v)) for v in box)
        color = colors[cls_id % len(colors)]
        if 0 <= cls_id < len(names):
            caption_name = names[cls_id]
        else:
            caption_name = str(cls_id)
        cv2.rectangle(canvas, (x0, y0), (x1, y1), color, 2)

        text = f"{caption_name} {float(score):.2f}"
        (tw, th), _ = cv2.getTextSize(text, font, 0.5, 1)
        # Keep the caption inside the image when the box touches the top edge.
        y_text = max(th + 2, y0)
        cv2.rectangle(
            canvas, (x0, y_text - th - 2), (x0 + tw + 2, y_text + 1), color, -1
        )
        cv2.putText(
            canvas, text, (x0 + 1, y_text - 2),
            font, 0.5, (255, 255, 255), 1, cv2.LINE_AA,
        )
    return canvas
195
+
196
+
197
def main():
    """Run the detector on one image and save an annotated copy.

    Parses the command line, loads the model, letterboxes the image to the
    model input size, runs inference ``--repeat`` times, decodes the three
    detection scales, applies per-class NMS, logs every detection and writes
    the visualisation to ``--img-save-path``.

    Exits with status 1 when the model or image is missing, the image cannot
    be read, or the result image cannot be written.

    Raises:
        ValueError: If the model does not return exactly six outputs.
    """
    ap = argparse.ArgumentParser(description="aquarium YOLOv8s 6-way axmodel inference (AXERARuntime)")
    ap.add_argument('--model-path', type=str, default='aquarium_yolov8s_6way.axmodel')
    ap.add_argument('--test-img', type=str, default='test.jpg')
    ap.add_argument('--img-save-path', type=str, default='result_aquarium_yolov8.jpg')
    ap.add_argument('--score-thres', type=float, default=PROB_THRESHOLD)
    ap.add_argument('--nms-thres', type=float, default=NMS_THRESHOLD)
    ap.add_argument('--repeat', type=int, default=1)
    ap.add_argument('--names', type=str, default=",".join(DEFAULT_NAMES))
    ap.add_argument('--providers', type=str, default='AxEngineExecutionProvider')
    opt = ap.parse_args()

    if not os.path.exists(opt.model_path):
        logger.error(f"Model not found: {opt.model_path}")
        sys.exit(1)
    if not os.path.exists(opt.test_img):
        logger.error(f"Image not found: {opt.test_img}")
        sys.exit(1)

    names = [s.strip() for s in opt.names.split(",") if s.strip()]
    cls_num = len(names)

    # perf_counter is monotonic, so the timings cannot be skewed by clock
    # adjustments.
    t0 = time.perf_counter()
    providers = [p.strip() for p in opt.providers.split(",") if p.strip()] or None
    sess = ort.InferenceSession(opt.model_path, providers=providers)
    logger.debug(f"\033[1;31mLoad model time = {(time.perf_counter() - t0) * 1000:.2f} ms\033[0m")

    inp = sess.get_inputs()[0]
    input_name = inp.name
    m_h, m_w, layout = infer_hw_layout(inp.shape)

    img = cv2.imread(opt.test_img)
    if img is None:
        logger.error(f"Failed to read image: {opt.test_img}")
        sys.exit(1)

    t0 = time.perf_counter()
    pad_bgr, meta = letterbox(img, m_h, m_w, pad_value=114)
    # NOTE(review): the frame is converted to RGB before inference; confirm
    # this matches the channel order the compiled model expects.
    rgb = cv2.cvtColor(pad_bgr, cv2.COLOR_BGR2RGB)
    if layout == "NHWC":
        input_tensor = rgb[None, ...].astype(np.uint8)
    else:
        input_tensor = np.transpose(rgb, (2, 0, 1))[None, ...].astype(np.uint8)
    logger.debug(f"\033[1;31mPre-process time = {(time.perf_counter() - t0) * 1000:.2f} ms\033[0m")

    out_infos = sess.get_outputs()
    out_names = [o.name for o in out_infos]

    times = []
    outs = None
    for _ in range(max(opt.repeat, 1)):
        t0 = time.perf_counter()
        outs = sess.run(None, {input_name: input_tensor})
        times.append((time.perf_counter() - t0) * 1000.0)
    logger.debug(
        f"\033[1;31mForward time min/avg/max = "
        f"{min(times):.2f}/{sum(times)/len(times):.2f}/{max(times):.2f} ms (n={len(times)})\033[0m"
    )

    # The loop above always runs at least once, so outs is set here.
    assert outs is not None
    if len(outs) != 6:
        raise ValueError(f"need 6 outputs, got {len(outs)}: {out_names}")

    t0 = time.perf_counter()
    by_s = group_outputs(out_names, outs, cls_num)
    boxes_all, scores_all, labels_all = [], [], []
    for s in STRIDES:
        cl, bb = by_s[s]
        s_boxes, s_probs, s_labels = decode_one_scale(s, cl, bb, opt.score_thres, m_h, m_w)
        if len(s_boxes):
            boxes_all.append(s_boxes)
            scores_all.append(s_probs)
            labels_all.append(s_labels)

    if boxes_all:
        boxes = np.concatenate(boxes_all)
        scores = np.concatenate(scores_all)
        labels = np.concatenate(labels_all)
        keep = per_class_nms(boxes, scores, labels, opt.score_thres, opt.nms_thres)
        boxes = unletterbox(boxes[keep], meta)
        scores = scores[keep]
        labels = labels[keep]
    else:
        boxes = np.empty((0, 4), np.float32)
        scores = np.empty((0,), np.float32)
        labels = np.empty((0,), np.int32)
    logger.debug(f"\033[1;31mPost-process time = {(time.perf_counter() - t0) * 1000:.2f} ms\033[0m")

    counts = {n: 0 for n in names}
    logger.info(f"\033[1;32mDetections: {len(boxes)}\033[0m")
    for b, s, c in zip(boxes, scores, labels):
        x0, y0, x1, y1 = b
        nm = names[int(c)] if 0 <= int(c) < len(names) else str(int(c))
        counts[nm] = counts.get(nm, 0) + 1
        logger.info(f"  {nm:8s} score={float(s):.3f} xyxy=({x0:.1f},{y0:.1f},{x1:.1f},{y1:.1f})")
    logger.info(f"per-class: {counts}")

    if opt.img_save_path:
        vis = draw(img, boxes, scores, labels, names, DEFAULT_COLORS)
        os.makedirs(os.path.dirname(os.path.abspath(opt.img_save_path)) or ".", exist_ok=True)
        # imwrite reports failure through its return value; do not claim
        # success when nothing was written.
        if not cv2.imwrite(opt.img_save_path, vis):
            logger.error(f"Failed to write image: {opt.img_save_path}")
            sys.exit(1)
        logger.info(f"Saved to {opt.img_save_path}")


if __name__ == "__main__":
    main()
result_aquarium_yolov8.jpg ADDED

Git LFS Details

  • SHA256: d2a6a603a914b2618b178fe6a3c53c2a0bba7ecfc59ee13be763bd32382739d5
  • Pointer size: 131 Bytes
  • Size of remote file: 400 kB
test.png ADDED

Git LFS Details

  • SHA256: fd9201f6f48e57d3fcf8d099d9cfdec52e7333c65f3a9e1992cef4aaa13c29b1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB