+import os
 import numpy as np
 import torch
+import json
+import subprocess
 from PIL import Image
 from deepsense import neptune
 from torch.autograd import Variable
+from tempfile import TemporaryDirectory
 
-from steps.pytorch.callbacks import NeptuneMonitor
-from utils import softmax, categorize_image
-from pipeline_config import CATEGORY_IDS
+import postprocessing as post
+from steps.base import Step, Dummy
+from steps.utils import get_logger
+from steps.pytorch.callbacks import NeptuneMonitor, ValidationMonitor
+from utils import softmax, categorize_image, coco_evaluation, create_annotations
+from pipeline_config import CATEGORY_IDS, Y_COLUMNS_SCORING
+
+logger = get_logger()
 
 
 class NeptuneMonitorSegmentation(NeptuneMonitor):
@@ -91,3 +100,140 @@ def get_prediction_masks(self): |
                 prediction_masks[mask_key] = np.stack([prediction, channel_ground_truth], axis=1)
             break
         return prediction_masks
+
+
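+# Extends ValidationMonitor so an epoch can be scored either with the regular
+# validation loss or with COCO mAP computed on postprocessed predictions.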
+class ValidationMonitorSegmentation(ValidationMonitor):
+    def __init__(self, data_dir, small_annotations_size, validate_with_map=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.data_dir = data_dir
+        self.small_annotations_size = small_annotations_size
+        self.validate_with_map = validate_with_map
+        self.validation_pipeline = postprocessing__pipeline_simplified
+        self.validation_loss = None
+        self.meta_valid = None
+
+    def set_params(self, transformer, validation_datagen, meta_valid=None, *args, **kwargs):
+        self.model = transformer.model
+        self.optimizer = transformer.optimizer
+        self.loss_function = transformer.loss_function
+        self.output_names = transformer.output_names
+        self.validation_datagen = validation_datagen
+        self.meta_valid = meta_valid
+        self.validation_loss = transformer.validation_loss
+
+    def get_validation_loss(self):
+        if self.validate_with_map:
+            return self._get_validation_loss()
+        else:
+            return super().get_validation_loss()
+
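+    # Scores the epoch with COCO evaluation: predictions are postprocessed,
+    # dumped to a temporary COCO-style JSON file and compared against the
+    # validation-set annotations.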
+    def _get_validation_loss(self):
+        with TemporaryDirectory() as temp_dir:
+            outputs = self._transform()
+            prediction = self._generate_prediction(temp_dir, outputs)
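+            # No detections at all: fall back to a score of 0 for this epoch.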
+            if len(prediction) == 0:
+                return self.validation_loss.setdefault(self.epoch_id, {'sum': Variable(torch.Tensor([0]))})
+
+            prediction_filepath = os.path.join(temp_dir, 'prediction.json')
+            with open(prediction_filepath, "w") as fp:
+                fp.write(json.dumps(prediction))
+
+            annotation_file_path = os.path.join(self.data_dir, 'val', "annotation.json")
+
+            logger.info('Calculating mean precision and recall')
+            average_precision, average_recall = coco_evaluation(gt_filepath=annotation_file_path,
+                                                                prediction_filepath=prediction_filepath,
+                                                                image_ids=self.meta_valid[Y_COLUMNS_SCORING].values,
+                                                                category_ids=CATEGORY_IDS[1:],
+                                                                small_annotations_size=self.small_annotations_size)
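+            # Note: the mAP is stored under the 'sum' key that normally holds
+            # the loss, so higher is better here; anything comparing epochs on
+            # this value needs to account for the flipped convention.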
+            return self.validation_loss.setdefault(self.epoch_id, {'sum': Variable(torch.Tensor([average_precision]))})
+
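+    # Runs the model over the validation generator in eval mode and returns
+    # softmaxed predictions keyed as '<output_name>_prediction'.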
+    def _transform(self):
+        self.model.eval()
+        batch_gen, steps = self.validation_datagen
+        outputs = {}
+        for batch_id, data in enumerate(batch_gen):
+            if isinstance(data, list):
+                X = data[0]
+            else:
+                X = data
+
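+            # volatile=True disables gradient bookkeeping during inference
+            # (pre-0.4 PyTorch idiom, superseded by torch.no_grad()).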
+            if torch.cuda.is_available():
+                X = Variable(X, volatile=True).cuda()
+            else:
+                X = Variable(X, volatile=True)
+
+            outputs_batch = self.model(X)
+            if len(self.output_names) == 1:
+                outputs.setdefault(self.output_names[0], []).append(outputs_batch.data.cpu().numpy())
+            else:
+                for name, output in zip(self.output_names, outputs_batch):
+                    output_ = output.data.cpu().numpy()
+                    outputs.setdefault(name, []).append(output_)
+            if batch_id == steps:
+                break
+        self.model.train()
+        outputs = {'{}_prediction'.format(name): np.vstack(outputs_) for name, outputs_ in outputs.items()}
+        for name, prediction in outputs.items():
+            outputs[name] = softmax(prediction, axis=1)
+
+        return outputs
+
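+    # Pushes raw model outputs through the simplified postprocessing pipeline
+    # and converts the result into COCO-style annotation dicts.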
+    def _generate_prediction(self, cache_dirpath, outputs):
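+        # target_sizes is hardcoded to 300x300, presumably the validation
+        # image size used throughout the pipeline.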
+        data = {'callback_input': {'meta': self.meta_valid,
+                                   'meta_valid': None,
+                                   'target_sizes': [(300, 300)] * len(self.meta_valid),
+                                   },
+                'unet_output': {**outputs}
+                }
+
+        pipeline = self.validation_pipeline(cache_dirpath)
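+        # Touch an empty transformer file for every step so that
+        # pipeline.transform() treats each transformer as already fitted
+        # (the postprocessing transformers are stateless).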
+        for step_name in pipeline.all_steps:
+            cmd = 'touch {}'.format(os.path.join(cache_dirpath, 'transformers', step_name))
+            subprocess.call(cmd, shell=True)
+        output = pipeline.transform(data)
+        y_pred = output['y_pred']
+
+        prediction = create_annotations(self.meta_valid, y_pred, logger, CATEGORY_IDS)
+        return prediction
+
+
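+# Simplified postprocessing pipeline: resize the multichannel probability maps,
+# map each pixel to its most probable category, label connected components and
+# attach a confidence score to every labeled instance.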
+def postprocessing__pipeline_simplified(cache_dirpath):
+    mask_resize = Step(name='mask_resize',
+                       transformer=post.Resizer(),
+                       input_data=['unet_output', 'callback_input'],
+                       adapter={'images': ([('unet_output', 'multichannel_map_prediction')]),
+                                'target_sizes': ([('callback_input', 'target_sizes')]),
+                                },
+                       cache_dirpath=cache_dirpath)
+
+    category_mapper = Step(name='category_mapper',
+                           transformer=post.CategoryMapper(),
+                           input_steps=[mask_resize],
+                           adapter={'images': ([(mask_resize.name, 'resized_images')]),
+                                    },
+                           cache_dirpath=cache_dirpath)
+
+    labeler = Step(name='labeler',
+                   transformer=post.MulticlassLabeler(),
+                   input_steps=[category_mapper],
+                   adapter={'images': ([(category_mapper.name, 'categorized_images')]),
+                            },
+                   cache_dirpath=cache_dirpath)
+
+    score_builder = Step(name='score_builder',
+                         transformer=post.ScoreBuilder(),
+                         input_steps=[labeler, mask_resize],
+                         adapter={'images': ([(labeler.name, 'labeled_images')]),
+                                  'probabilities': ([(mask_resize.name, 'resized_images')]),
+                                  },
+                         cache_dirpath=cache_dirpath)
+
+    output = Step(name='output',
+                  transformer=Dummy(),
+                  input_steps=[score_builder],
+                  adapter={'y_pred': ([(score_builder.name, 'images_with_scores')]),
+                           },
+                  cache_dirpath=cache_dirpath)
+
+    return output