
Commit 905e9e5

michal2409 authored and asulecki committed
[nnUnet/PyT] Add support for Triton
1 parent 2a2735f commit 905e9e5


59 files changed: +32111 −123 lines changed

PyTorch/Segmentation/nnUNet/Dockerfile

Lines changed: 8 additions & 0 deletions
@@ -6,6 +6,7 @@ WORKDIR /workspace/nnunet_pyt
 
 RUN pip install --upgrade pip
 RUN pip install --disable-pip-version-check -r requirements.txt
+RUN pip install --disable-pip-version-check -r triton/requirements.txt
 RUN pip install pytorch-lightning==1.0.0 --no-dependencies
 RUN pip install monai==0.4.0 --no-dependencies
 RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.30.0
@@ -14,3 +15,10 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2
 RUN unzip -qq awscliv2.zip
 RUN ./aws/install
 RUN rm -rf awscliv2.zip aws
+
+# Install Perf Client required library
+RUN apt-get update && apt-get install -y libb64-dev libb64-0d
+
+# Install Triton Client Python API and copy Perf Client
+#COPY --from=triton-client /workspace/install/ /workspace/install/
+#RUN pip install /workspace/install/python/triton*.whl

PyTorch/Segmentation/nnUNet/README.md

Lines changed: 8 additions & 11 deletions
@@ -134,10 +134,6 @@ TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by defaul
 
 Test time augmentation is an inference technique which averages predictions from augmented images with its prediction. As a result, predictions are more accurate, but with the cost of slower inference process. For nnU-Net, we use all possible flip combinations for image augmenting. Test time augmentation can be enabled by adding the `--tta` flag.
 
-**Deep supervision**
-
-Deep supervision is a technique which adds auxiliary loss in U-Net decoder. For nnU-Net, we add auxiliary losses to all but the lowest two decoder levels. Final loss is the weighted average of losses. Deep supervision can be enabled by adding the `--deep_supervision` flag.
-
 ## Setup
 
 The following section lists the requirements that you need to meet in order to start training the nnU-Net model.
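The flip-based test time augmentation kept in the context above can be sketched in a few lines (an illustration only, assuming a `model` that maps NCDHW tensors to logits; the helper name `tta_average` and the explicit eight axis combinations mirror the description rather than the repository's exact `get_tta_flips` code):

```python
import itertools

import torch


def tta_average(model, img):
    # Average predictions over every combination of spatial flips
    # (2^3 = 8 variants for a 3D NCDHW volume), undoing each flip on the output.
    spatial_dims = (2, 3, 4)  # D, H, W axes
    flips = [dims for r in range(len(spatial_dims) + 1)
             for dims in itertools.combinations(spatial_dims, r)]
    preds = []
    for dims in flips:
        aug = torch.flip(img, dims) if dims else img
        out = model(aug)
        preds.append(torch.flip(out, dims) if dims else out)
    return torch.stack(preds).mean(dim=0)
```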
@@ -308,7 +304,7 @@ To see the full list of available options and their descriptions, use the `-h` o
 The following example output is printed when running the model:
 
 ```
-usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--deep_supervision] [--drop_block] [--attention] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,radam,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
+usage: main.py [-h] [--exec_mode {train,evaluate,predict}] [--data DATA] [--results RESULTS] [--logname LOGNAME] [--task TASK] [--gpus GPUS] [--learning_rate LEARNING_RATE] [--gradient_clip_val GRADIENT_CLIP_VAL] [--negative_slope NEGATIVE_SLOPE] [--tta] [--amp] [--benchmark] [--residual] [--focal] [--sync_batchnorm] [--save_ckpt] [--nfolds NFOLDS] [--seed SEED] [--skip_first_n_eval SKIP_FIRST_N_EVAL] [--ckpt_path CKPT_PATH] [--fold FOLD] [--patience PATIENCE] [--lr_patience LR_PATIENCE] [--batch_size BATCH_SIZE] [--val_batch_size VAL_BATCH_SIZE] [--steps STEPS [STEPS ...]] [--profile] [--momentum MOMENTUM] [--weight_decay WEIGHT_DECAY] [--save_preds] [--dim {2,3}] [--resume_training] [--factor FACTOR] [--num_workers NUM_WORKERS] [--min_epochs MIN_EPOCHS] [--max_epochs MAX_EPOCHS] [--warmup WARMUP] [--norm {instance,batch,group}] [--nvol NVOL] [--data2d_dim {2,3}] [--oversampling OVERSAMPLING] [--overlap OVERLAP] [--affinity {socket,single,single_unique,socket_unique_interleaved,socket_unique_continuous,disabled}] [--scheduler {none,multistep,cosine,plateau}] [--optimizer {sgd,radam,adam}] [--blend {gaussian,constant}] [--train_batches TRAIN_BATCHES] [--test_batches TEST_BATCHES]
 
 optional arguments:
   -h, --help            show this help message and exit
@@ -328,9 +324,6 @@ optional arguments:
   --tta                 Enable test time augmentation (default: False)
   --amp                 Enable automatic mixed precision (default: False)
   --benchmark           Run model benchmarking (default: False)
-  --deep_supervision    Enable deep supervision (default: False)
-  --drop_block          Enable drop block (default: False)
-  --attention           Enable attention in decoder (default: False)
   --residual            Enable residual block in encoder (default: False)
   --focal               Use focal loss instead of cross entropy (default: False)
   --sync_batchnorm      Enable synchronized batchnorm (default: False)
@@ -435,7 +428,7 @@ The default configuration minimizes a function `L = (1 - dice_coefficient) + cro
 The training can be run directly without using the predefined scripts. The name of the training script is `main.py`. For example:
 
 ```
-python main.py --exec_mode train --task 01 --fold 0 --gpus 1 --amp --deep_supervision
+python main.py --exec_mode train --task 01 --fold 0 --gpus 1 --amp
 ```
 
 Training artifacts will be saved to `/results` in the container. Some important artifacts are:
@@ -612,7 +605,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre
 
 FP16
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |Resolution| Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1866.52 | 34.29 | 34.7 | 48.87 | 52.44 |
 | 2 | 128 | 4x192x160 | 2032.74 | 62.97 | 63.21 | 63.25 | 63.32 |
@@ -622,7 +615,7 @@ FP16
 
 FP32
 
-| Dimension | Batch size | Resolution | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |Resolution| Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1051.46 | 60.87 | 61.21 | 61.48 | 62.87 |
 | 2 | 128 | 4x192x160 | 1051.68 | 121.71 | 122.29 | 122.44 | 122.6 |
@@ -638,6 +631,10 @@ To achieve these same results, follow the steps in the [Quick Start Guide](#quic
 
 ### Changelog
 
+May 2021
+- Add Triton Inference Server support
+- Removed deep supervision, attention and drop block
+
 March 2021
 - Container updated to 21.02
 - Change data format from tfrecord to npy and data loading for 2D

PyTorch/Segmentation/nnUNet/data_loading/dali_loader.py

Lines changed: 35 additions & 5 deletions
@@ -160,6 +160,37 @@ def define_graph(self):
         return img, lbl
 
 
+class BermudaPipeline(Pipeline):
+    def __init__(self, batch_size, num_threads, device_id, **kwargs):
+        super(BermudaPipeline, self).__init__(batch_size, num_threads, device_id)
+        self.input_x = get_numpy_reader(
+            files=kwargs["imgs"],
+            shard_id=device_id,
+            num_shards=kwargs["gpus"],
+            seed=kwargs["seed"],
+            shuffle=False,
+        )
+        self.input_y = get_numpy_reader(
+            files=kwargs["lbls"],
+            shard_id=device_id,
+            num_shards=kwargs["gpus"],
+            seed=kwargs["seed"],
+            shuffle=False,
+        )
+        self.patch_size = kwargs["patch_size"]
+
+    def crop_fn(self, img, lbl):
+        img = fn.crop(img, crop=self.patch_size, out_of_bounds_policy="pad")
+        lbl = fn.crop(lbl, crop=self.patch_size, out_of_bounds_policy="pad")
+        return img, lbl
+
+    def define_graph(self):
+        img, lbl = self.input_x(name="ReaderX"), self.input_y(name="ReaderY")
+        img, lbl = fn.reshape(img, layout="CDHW"), fn.reshape(lbl, layout="CDHW")
+        img, lbl = self.crop_fn(img, lbl)
+        return img, lbl
+
+
 class TestPipeline(Pipeline):
     def __init__(self, batch_size, num_threads, device_id, **kwargs):
         super(TestPipeline, self).__init__(batch_size, num_threads, device_id)
@@ -249,11 +280,6 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
             nbs *= batch_size
        imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * kwargs["gpus"]]
        lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * kwargs["gpus"]]
-    if mode == "eval":
-        reminder = len(imgs) % kwargs["gpus"]
-        if reminder != 0:
-            imgs = imgs[:-reminder]
-            lbls = lbls[:-reminder]
 
     pipe_kwargs = {
         "imgs": imgs,
@@ -284,6 +310,10 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
         pipeline = EvalPipeline
         output_map = ["image", "label"]
         dynamic_shape = True
+    elif mode == "bermuda":
+        pipeline = BermudaPipeline
+        output_map = ["image", "label"]
+        dynamic_shape = False
     else:
         pipeline = TestPipeline
         output_map = ["image", "meta"]
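Because the new `bermuda` mode sets `dynamic_shape = False`, every batch the pipeline emits has the same padded patch shape, which suits a static-shape serving path. A minimal smoke test of the pipeline on its own might look like this (a sketch, not code from the commit: the `.npy` file names, batch size, and patch size are placeholders):

```python
# Hypothetical check that BermudaPipeline yields fixed-shape CDHW patches.
pipe = BermudaPipeline(
    batch_size=2,
    num_threads=4,
    device_id=0,
    imgs=["case_00_x.npy", "case_01_x.npy"],  # placeholder image volumes
    lbls=["case_00_y.npy", "case_01_y.npy"],  # placeholder label volumes
    gpus=1,
    seed=0,
    patch_size=[128, 128, 128],
)
pipe.build()
img_batch, lbl_batch = pipe.run()  # one batch, cropped/padded to patch_size
```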

PyTorch/Segmentation/nnUNet/models/layers.py

Lines changed: 0 additions & 49 deletions
@@ -15,7 +15,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from dropblock import DropBlock3D, LinearScheduler
 
 normalizations = {
     "instancenorm3d": nn.InstanceNorm3d,
@@ -68,30 +67,16 @@ def get_output_padding(kernel_size, stride, padding):
     return out_padding if len(out_padding) > 1 else out_padding[0]
 
 
-def get_drop_block():
-    return LinearScheduler(
-        DropBlock3D(block_size=5, drop_prob=0.0),
-        start_value=0.0,
-        stop_value=0.1,
-        nr_steps=10000,
-    )
-
 
 class ConvLayer(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
         super(ConvLayer, self).__init__()
         self.conv = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
         self.norm = get_norm(kwargs["norm"], out_channels)
         self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
-        self.use_drop_block = kwargs["drop_block"]
-        if self.use_drop_block:
-            self.drop_block = get_drop_block()
 
     def forward(self, data):
         out = self.conv(data)
-        if self.use_drop_block:
-            self.drop_block.step()
-            out = self.drop_block(out)
         out = self.norm(out)
         out = self.lrelu(out)
         return out
@@ -116,10 +101,6 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
         self.conv2 = get_conv(out_channels, out_channels, kernel_size, 1, kwargs["dim"])
         self.norm = get_norm(kwargs["norm"], out_channels)
         self.lrelu = nn.LeakyReLU(negative_slope=kwargs["negative_slope"], inplace=True)
-        self.use_drop_block = kwargs["drop_block"]
-        if self.use_drop_block:
-            self.drop_block = get_drop_block()
-            self.skip_drop_block = get_drop_block()
         self.downsample = None
         if max(stride) > 1 or in_channels != out_channels:
             self.downsample = get_conv(in_channels, out_channels, kernel_size, stride, kwargs["dim"])
@@ -129,52 +110,22 @@ def forward(self, input_data):
         residual = input_data
         out = self.conv1(input_data)
         out = self.conv2(out)
-        if self.use_drop_block:
-            out = self.drop_block(out)
         out = self.norm(out)
         if self.downsample is not None:
             residual = self.downsample(residual)
-        if self.use_drop_block:
-            residual = self.skip_drop_block(residual)
         residual = self.norm_res(residual)
         out = self.lrelu(out + residual)
         return out
 
 
-class AttentionLayer(nn.Module):
-    def __init__(self, in_channels, out_channels, norm, dim):
-        super(AttentionLayer, self).__init__()
-        self.conv = get_conv(in_channels, out_channels, kernel_size=3, stride=1, dim=dim)
-        self.norm = get_norm(norm, out_channels)
-
-    def forward(self, inputs):
-        out = self.conv(inputs)
-        out = self.norm(out)
-        return out
-
-
 class UpsampleBlock(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
         super(UpsampleBlock, self).__init__()
         self.transp_conv = get_transp_conv(in_channels, out_channels, stride, stride, kwargs["dim"])
         self.conv_block = ConvBlock(2 * out_channels, out_channels, kernel_size, 1, **kwargs)
-        self.attention = kwargs["attention"]
-        if self.attention:
-            att_out, norm, dim = out_channels // 2, kwargs["norm"], kwargs["dim"]
-            self.conv_o = AttentionLayer(out_channels, att_out, norm, dim)
-            self.conv_s = AttentionLayer(out_channels, att_out, norm, dim)
-            self.psi = AttentionLayer(att_out, 1, norm, dim)
-            self.sigmoid = nn.Sigmoid()
-            self.relu = nn.ReLU(inplace=True)
 
     def forward(self, input_data, skip_data):
         out = self.transp_conv(input_data)
-        if self.attention:
-            out_a = self.conv_o(out)
-            skip_a = self.conv_s(skip_data)
-            psi_a = self.psi(self.relu(out_a + skip_a))
-            attention = self.sigmoid(psi_a)
-            skip_data = skip_data * attention
         out = torch.cat((out, skip_data), dim=1)
         out = self.conv_block(out)
         return out
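For the record, the deleted attention path implemented an additive attention gate on the skip connection. Restated from the removed forward pass (a reading of the deleted code, with u the upsampled decoder features and s the skip features):

```latex
% conv_o, conv_s, psi are the three removed AttentionLayer instances (conv + norm)
\alpha = \sigma\!\big(\psi(\mathrm{ReLU}(\mathrm{conv_o}(u) + \mathrm{conv_s}(s)))\big),
\qquad \hat{s} = \alpha \odot s
```

The gated skip tensor ŝ then replaced s in the channel-wise concatenation with u.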

PyTorch/Segmentation/nnUNet/models/nn_unet.py

Lines changed: 17 additions & 25 deletions
@@ -39,28 +39,33 @@
 
 
 class NNUnet(pl.LightningModule):
-    def __init__(self, args):
+    def __init__(self, args, bermuda=False, data_dir=None):
         super(NNUnet, self).__init__()
         self.args = args
-        if not hasattr(self.args, "drop_block"):  # For backward compability
-            self.args.drop_block = False
+        self.bermuda = bermuda
+        if data_dir is not None:
+            self.args.data = data_dir
         self.save_hyperparameters()
         self.build_nnunet()
-        self.loss = Loss(self.args.focal)
-        self.dice = Dice(self.n_class)
         self.best_sum = 0
         self.best_sum_epoch = 0
         self.best_dice = self.n_class * [0]
         self.best_epoch = self.n_class * [0]
         self.best_sum_dice = self.n_class * [0]
-        self.learning_rate = args.learning_rate
-        self.tta_flips = get_tta_flips(args.dim)
         self.test_idx = 0
         self.test_imgs = []
-        if self.args.exec_mode in ["train", "evaluate"]:
-            self.dllogger = get_dllogger(args.results)
+        if not self.bermuda:
+            self.learning_rate = args.learning_rate
+            self.loss = Loss(self.args.focal)
+            self.tta_flips = get_tta_flips(args.dim)
+            self.dice = Dice(self.n_class)
+            if self.args.exec_mode in ["train", "evaluate"]:
+                self.dllogger = get_dllogger(args.results)
 
     def forward(self, img):
+        return torch.argmax(self.model(img), 1)
+
+    def _forward(self, img):
         if self.args.benchmark:
             if self.args.dim == 2 and self.args.data2d_dim == 3:
                 img = layout_2d(img, None)
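With `forward()` now reduced to an argmax over the network's logits, the Lightning module can be traced directly into a deployable artifact. A hedged sketch of what that export might look like (not code from this commit; `NNUnet` and `args` come from the diff, while the checkpoint-free construction, input shape, and file name are illustrative assumptions):

```python
import torch

# Hypothetical export sketch: bermuda=True skips loss/metric/logger setup,
# leaving a pure inference graph that ends in torch.argmax.
model = NNUnet(args, bermuda=True)
model.eval()

# Illustrative 3D input: (batch, channels, depth, height, width).
dummy = torch.randn(1, 4, 128, 128, 128)
with torch.no_grad():
    traced = torch.jit.trace(model, dummy)
traced.save("nnunet_ts.pt")  # a file a Triton model repository could serve
```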
@@ -70,14 +75,14 @@ def forward(self, img):
     def training_step(self, batch, batch_idx):
         img, lbl = self.get_train_data(batch)
         pred = self.model(img)
-        loss = self.compute_loss(pred, lbl)
+        loss = self.loss(pred, lbl)
         return loss
 
     def validation_step(self, batch, batch_idx):
         if self.current_epoch < self.args.skip_first_n_eval:
             return None
         img, lbl = batch["image"], batch["label"]
-        pred = self.forward(img)
+        pred = self._forward(img)
         loss = self.loss(pred, lbl)
         self.dice.update(pred, lbl[:, 0])
         return {"val_loss": loss}
@@ -86,7 +91,7 @@ def test_step(self, batch, batch_idx):
         if self.args.exec_mode == "evaluate":
             return self.validation_step(batch, batch_idx)
         img = batch["image"]
-        pred = self.forward(img)
+        pred = self._forward(img)
         if self.args.save_preds:
             meta = batch["meta"][0].cpu().detach().numpy()
             original_shape = meta[2]
@@ -120,25 +125,12 @@ def build_nnunet(self):
             strides=strides,
             dimension=self.args.dim,
             residual=self.args.residual,
-            attention=self.args.attention,
-            drop_block=self.args.drop_block,
             normalization_layer=self.args.norm,
             negative_slope=self.args.negative_slope,
-            deep_supervision=self.args.deep_supervision,
         )
         if is_main_process():
             print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
 
-    def compute_loss(self, preds, label):
-        if self.args.deep_supervision:
-            loss = self.loss(preds[0], label)
-            for i, pred in enumerate(preds[1:]):
-                downsampled_label = nn.functional.interpolate(label, pred.shape[2:])
-                loss += 0.5 ** (i + 1) * self.loss(pred, downsampled_label)
-            c_norm = 1 / (2 - 2 ** (-len(preds)))
-            return c_norm * loss
-        return self.loss(preds, label)
-
     def do_inference(self, image):
         if self.args.dim == 3:
             return self.sliding_window_inference(image)
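The deleted `compute_loss` was the implementation of the deep-supervision loss that the README section above also drops. Restated in formula form (a transcription of the removed code, with k = len(preds), base loss ℓ, p_0 the full-resolution head, and y_i the label interpolated to the shape of p_i):

```latex
L = \frac{1}{2 - 2^{-k}} \left( \ell(p_0,\, y) + \sum_{i=1}^{k-1} 2^{-i}\, \ell(p_i,\, y_i) \right)
```

Each coarser auxiliary head contributed half the weight of the previous one, and the leading constant renormalized the weighted sum.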
