1+ import cv2 as cv
2+ from pathlib import Path
3+ import natsort
4+ import numpy as np
5+
6+ def one_hot_it (label , label_values ):
7+ """
8+ Convert a segmentation image label array to one-hot format
9+ by replacing each pixel value with a vector of length num_classes
10+
11+ # Arguments
12+ label: The 2D array segmentation image label
13+ label_values
14+
15+ # Returns
16+ A 2D array with the same width and hieght as the input, but
17+ with a depth size of num_classes
18+ """
19+ semantic_map = []
20+ for colour in label_values :
21+ # colour_map = np.full((label.shape[0], label.shape[1], label.shape[2]), colour, dtype=int)
22+ equality = np .equal (label , colour )
23+ class_map = np .all (equality , axis = - 1 )
24+ semantic_map .append (class_map )
25+ semantic_map = np .stack (semantic_map , axis = - 1 )
26+
27+ return semantic_map
28+
29+ def reverse_one_hot (image ):
30+ """
31+ Transform a 2D array in one-hot format (depth is num_classes),
32+ to a 2D array with only 1 channel, where each pixel value is
33+ the classified class key.
34+
35+ # Arguments
36+ image: The one-hot format image
37+
38+ # Returns
39+ A 2D array with the same width and hieght as the input, but
40+ with a depth size of 1, where each pixel value is the classified
41+ class key.
42+ """
43+ x = np .argmax (image , axis = - 1 )
44+ return x
45+
46+ def load_image (path ):
47+ image = cv .cvtColor (cv .imread (path , 1 ), cv .COLOR_BGR2RGB )
48+ return image
49+
50+ GT_Path = Path ("path-to-original-label-images" )
51+ GT_File = natsort .natsorted (list (GT_Path .glob ("*.png" )), alg = natsort .PATH )
52+ GT_Str = []
53+ for i in GT_File :
54+ GT_Str .append (str (i ))
55+
56+ out_prefix = "precoded_label"
57+ label_values = [[255 , 255 , 255 ], [0 , 0 , 255 ], [0 , 255 , 255 ], [0 , 255 , 0 ], [255 , 255 , 0 ], [255 , 0 , 0 ]]
58+ for k in range (len (GT_Str )):
59+ gt = load_image (GT_Str [k ])
60+ out = reverse_one_hot (one_hot_it (gt ,label_values ))
61+ out_str = out_prefix + Path (GT_Str [k ]).name
62+ cv .imwrite (out_str ,out )
63+ # print("kk")
0 commit comments