2i commited on
Commit
4f0f14d
·
verified ·
1 Parent(s): a8d8513

Update detection/base.py

Browse files
Files changed (1) hide show
  1. detection/base.py +6 -5
detection/base.py CHANGED
@@ -53,18 +53,18 @@ class ObjectDetection:
53
 
54
  def _gr_detect(self, image: ImageTyping, model_name: str,
55
  iou_threshold: float = 0.7, score_threshold: float = 0.25) \
56
- -> gr.AnnotatedImage:
57
  labels = self.get_labels(model_name=model_name)
58
  _colors = list(map(str, rnd_colors(len(labels))))
59
  _color_map = dict(zip(labels, _colors))
 
60
  return gr.AnnotatedImage(
61
  value=(image, [
62
- (_bbox_fix(bbox), label) for bbox, label, _ in
63
- self.detect(image, model_name, iou_threshold, score_threshold)
64
  ]),
65
  color_map=_color_map,
66
  label='Labeled',
67
- )
68
 
69
  def make_ui(self):
70
  with gr.Row():
@@ -82,6 +82,7 @@ class ObjectDetection:
82
 
83
  with gr.Column():
84
  gr_output_image = gr.AnnotatedImage(label="Labeled")
 
85
 
86
  gr_submit.click(
87
  self._gr_detect,
@@ -91,7 +92,7 @@ class ObjectDetection:
91
  gr_iou_threshold,
92
  gr_score_threshold,
93
  ],
94
- outputs=[gr_output_image],
95
  )
96
 
97
 
 
53
 
54
  def _gr_detect(self, image: ImageTyping, model_name: str,
55
  iou_threshold: float = 0.7, score_threshold: float = 0.25) \
56
+ -> Tuple[gr.AnnotatedImage, List[Tuple[Tuple[float, float, float, float], str, float]]]:
57
  labels = self.get_labels(model_name=model_name)
58
  _colors = list(map(str, rnd_colors(len(labels))))
59
  _color_map = dict(zip(labels, _colors))
60
+ detections = self.detect(image, model_name, iou_threshold, score_threshold)
61
  return gr.AnnotatedImage(
62
  value=(image, [
63
+ (_bbox_fix(bbox), label) for bbox, label, _ in detections
 
64
  ]),
65
  color_map=_color_map,
66
  label='Labeled',
67
+ ), detections
68
 
69
  def make_ui(self):
70
  with gr.Row():
 
82
 
83
  with gr.Column():
84
  gr_output_image = gr.AnnotatedImage(label="Labeled")
85
+ gr_detections = gr.JSON(label="Detections")
86
 
87
  gr_submit.click(
88
  self._gr_detect,
 
92
  gr_iou_threshold,
93
  gr_score_threshold,
94
  ],
95
+ outputs=[gr_output_image, gr_detections],
96
  )
97
 
98