summaryrefslogtreecommitdiffstats
path: root/lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py')
-rw-r--r--lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py b/lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py
new file mode 100644
index 00000000..1906e9fe
--- /dev/null
+++ b/lib/oeqa/runtime/files/dldt-inference-engine/classification_sample.py
@@ -0,0 +1,135 @@
1#!/usr/bin/env python3
2"""
3 Copyright (C) 2018-2019 Intel Corporation
4
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8
9 http://www.apache.org/licenses/LICENSE-2.0
10
11 Unless required by applicable law or agreed to in writing, software
12 distributed under the License is distributed on an "AS IS" BASIS,
13 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 See the License for the specific language governing permissions and
15 limitations under the License.
16"""
17from __future__ import print_function
18import sys
19import os
20from argparse import ArgumentParser, SUPPRESS
21import cv2
22import numpy as np
23import logging as log
24from time import time
25from openvino.inference_engine import IENetwork, IECore
26
27
28def build_argparser():
29 parser = ArgumentParser(add_help=False)
30 args = parser.add_argument_group('Options')
31 args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
32 args.add_argument("-m", "--model", help="Required. Path to an .xml file with a trained model.", required=True,
33 type=str)
34 args.add_argument("-i", "--input", help="Required. Path to a folder with images or path to an image files",
35 required=True,
36 type=str, nargs="+")
37 args.add_argument("-l", "--cpu_extension",
38 help="Optional. Required for CPU custom layers. "
39 "MKLDNN (CPU)-targeted custom layers. Absolute path to a shared library with the"
40 " kernels implementations.", type=str, default=None)
41 args.add_argument("-d", "--device",
42 help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL, MYRIAD or HETERO: is "
43 "acceptable. The sample will look for a suitable plugin for device specified. Default "
44 "value is CPU",
45 default="CPU", type=str)
46 args.add_argument("--labels", help="Optional. Path to a labels mapping file", default=None, type=str)
47 args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
48
49 return parser
50
51
52def main():
53 log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
54 args = build_argparser().parse_args()
55 model_xml = args.model
56 model_bin = os.path.splitext(model_xml)[0] + ".bin"
57
58 # Plugin initialization for specified device and load extensions library if specified
59 log.info("Creating Inference Engine")
60 ie = IECore()
61 if args.cpu_extension and 'CPU' in args.device:
62 ie.add_extension(args.cpu_extension, "CPU")
63 # Read IR
64 log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))
65 net = IENetwork(model=model_xml, weights=model_bin)
66
67 if "CPU" in args.device:
68 supported_layers = ie.query_network(net, "CPU")
69 not_supported_layers = [l for l in net.layers.keys() if l not in supported_layers]
70 if len(not_supported_layers) != 0:
71 log.error("Following layers are not supported by the plugin for specified device {}:\n {}".
72 format(args.device, ', '.join(not_supported_layers)))
73 log.error("Please try to specify cpu extensions library path in sample's command line parameters using -l "
74 "or --cpu_extension command line argument")
75 sys.exit(1)
76
77 assert len(net.inputs.keys()) == 1, "Sample supports only single input topologies"
78 assert len(net.outputs) == 1, "Sample supports only single output topologies"
79
80 log.info("Preparing input blobs")
81 input_blob = next(iter(net.inputs))
82 out_blob = next(iter(net.outputs))
83 net.batch_size = len(args.input)
84
85 # Read and pre-process input images
86 n, c, h, w = net.inputs[input_blob].shape
87 images = np.ndarray(shape=(n, c, h, w))
88 for i in range(n):
89 image = cv2.imread(args.input[i])
90 if image.shape[:-1] != (h, w):
91 log.warning("Image {} is resized from {} to {}".format(args.input[i], image.shape[:-1], (h, w)))
92 image = cv2.resize(image, (w, h))
93 image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
94 images[i] = image
95 log.info("Batch size is {}".format(n))
96
97 # Loading model to the plugin
98 log.info("Loading model to the plugin")
99 exec_net = ie.load_network(network=net, device_name=args.device)
100
101 # Start sync inference
102 log.info("Starting inference in synchronous mode")
103 res = exec_net.infer(inputs={input_blob: images})
104
105 # Processing output blob
106 log.info("Processing output blob")
107 res = res[out_blob]
108 log.info("Top {} results: ".format(args.number_top))
109 if args.labels:
110 with open(args.labels, 'r') as f:
111 labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
112 else:
113 labels_map = None
114 classid_str = "classid"
115 probability_str = "probability"
116 for i, probs in enumerate(res):
117 probs = np.squeeze(probs)
118 top_ind = np.argsort(probs)[-args.number_top:][::-1]
119 print("Image {}\n".format(args.input[i]))
120 print(classid_str, probability_str)
121 print("{} {}".format('-' * len(classid_str), '-' * len(probability_str)))
122 for id in top_ind:
123 det_label = labels_map[id] if labels_map else "{}".format(id)
124 label_length = len(det_label)
125 space_num_before = (len(classid_str) - label_length) // 2
126 space_num_after = len(classid_str) - (space_num_before + label_length) + 2
127 space_num_before_prob = (len(probability_str) - len(str(probs[id]))) // 2
128 print("{}{}{}{}{:.7f}".format(' ' * space_num_before, det_label,
129 ' ' * space_num_after, ' ' * space_num_before_prob,
130 probs[id]))
131 print("\n")
132 log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
133
134if __name__ == '__main__':
135 sys.exit(main() or 0)