Segmentation Models is a Python library with neural networks for image segmentation based on the Keras (TensorFlow) framework.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 25 available backbones for each architecture
- All backbones have pre-trained weights for faster and better convergence
Since the library is built on the Keras framework, a created segmentation model is just a Keras Model, which can be created as easily as:
```python
from segmentation_models import Unet

model = Unet()
```
Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrained weights to initialize it:
```python
model = Unet('resnet34', encoder_weights='imagenet')
```
Change the number of output classes in the model (choose your case):
```python
# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
model = Unet('resnet34', classes=1, activation='sigmoid')

# multiclass segmentation with non-overlapping class masks (your classes + background)
model = Unet('resnet34', classes=3, activation='softmax')

# multiclass segmentation with independent overlapping/non-overlapping class masks
model = Unet('resnet34', classes=3, activation='sigmoid')
```
Change the input shape of the model:
```python
# if the number of input channels is not 3, you have to set encoder_weights=None
# how to handle such a case with encoder_weights='imagenet' is described in the docs
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
```
A simple training pipeline:
```python
from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data (load_data is a placeholder for your own loading code)
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
# if you use a data generator, use model.fit_generator(...) instead of model.fit(...)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
```
The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see Read the Docs.
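The `classes`/`activation` combinations in the examples above also determine what the target masks in `y_train` should look like. Below is a minimal NumPy sketch, assuming masks are stored channels-last, as the `(None, None, 6)` input shape suggests. The helpers `to_one_hot`, `decode_softmax` and `decode_sigmoid` are hypothetical and not part of segmentation_models, and the 0.5 threshold is an assumption. With `activation='softmax'` each pixel's target is a one-hot vector over all classes (background included) and predictions are decoded with `argmax`; with `activation='sigmoid'` every channel is an independent 0/1 mask, so predictions are thresholded channel by channel.

```python
import numpy as np


def to_one_hot(label_map, num_classes):
    """Convert integer labels of shape (..., H, W) into one-hot masks (..., H, W, C).

    Hypothetical helper: target format for classes=num_classes, activation='softmax',
    where every pixel belongs to exactly one class, background included.
    """
    return np.eye(num_classes, dtype=np.float32)[label_map]


def decode_softmax(pred):
    """Map softmax output (..., H, W, C) back to an integer label map (..., H, W)."""
    return np.argmax(pred, axis=-1)


def decode_sigmoid(pred, threshold=0.5):
    """Threshold sigmoid output (..., H, W, C) into independent 0/1 masks.

    Channels may overlap, so a pixel can be "on" in several masks at once.
    The 0.5 default threshold is an assumption, not a library setting.
    """
    return (pred > threshold).astype(np.uint8)


# Usage: a 2x2 label map with classes 0 (background), 1 and 2
labels = np.array([[0, 1],
                   [2, 1]])
y = to_one_hot(labels, num_classes=3)
print(y.shape)            # (2, 2, 3)
print(decode_softmax(y))  # recovers `labels`
```

Because the functions index and reduce along the last axis only, they work unchanged on a whole batch shaped `(N, H, W)` or `(N, H, W, C)`.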
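The training pipeline above optimizes `bce_jaccard_loss` and reports `iou_score`. To show what these measure, here is an illustrative NumPy re-implementation of the underlying formulas, not the library's code: IoU (the Jaccard index) is the intersection of the ground-truth and predicted masks divided by their union, and the loss is binary cross-entropy plus the Jaccard loss `1 - IoU`. The names `numpy_iou`, `numpy_bce` and `numpy_bce_jaccard`, the smoothing constant, the clipping epsilon and the reduction over the whole array (rather than per image) are all assumptions.

```python
import numpy as np

SMOOTH = 1.0  # assumed smoothing term; keeps the score defined for empty masks


def numpy_iou(gt, pr, smooth=SMOOTH):
    """Jaccard index: intersection over union of ground-truth and predicted masks.

    Works on binary or soft (probability) masks; sums over the whole array.
    """
    intersection = np.sum(gt * pr)
    union = np.sum(gt) + np.sum(pr) - intersection
    return (intersection + smooth) / (union + smooth)


def numpy_bce(gt, pr, eps=1e-7):
    """Mean pixel-wise binary cross-entropy; clipping by eps avoids log(0)."""
    pr = np.clip(pr, eps, 1.0 - eps)
    return float(np.mean(-(gt * np.log(pr) + (1.0 - gt) * np.log(1.0 - pr))))


def numpy_bce_jaccard(gt, pr):
    """BCE plus Jaccard loss (1 - IoU); both terms shrink as pr approaches gt."""
    return numpy_bce(gt, pr) + (1.0 - numpy_iou(gt, pr))


# Usage: the prediction covers one of the two foreground pixels
gt = np.array([[1.0, 1.0],
               [0.0, 0.0]])
pr = np.array([[1.0, 0.0],
               [0.0, 0.0]])
print(numpy_iou(gt, pr, smooth=0.0))  # intersection 1 / union 2 = 0.5
```

Combining the two terms is a common design choice: cross-entropy gives smooth per-pixel gradients, while the Jaccard term directly rewards overlap and is less dominated by a large background.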
Models
Available architectures: Unet, Linknet, PSPNet and FPN.
Backbones
| Type | Names |
|---|---|
| VGG | 'vgg16' 'vgg19' |
| ResNet | 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' |
| SE-ResNet | 'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152' |
| ResNeXt | 'resnext50' 'resnext101' |
| SE-ResNeXt | 'seresnext50' 'seresnext101' |
| SENet154 | 'senet154' |
| DenseNet | 'densenet121' 'densenet169' 'densenet201' |
| Inception | 'inceptionv3' 'inceptionresnetv2' |
| MobileNet | 'mobilenet' 'mobilenetv2' |
All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (`encoder_weights='imagenet'`).
Requirements
- Python 3.5+
- Keras >= 2.2.0
- Keras Applications >= 1.0.7
- Image Classifiers == 0.2.0
- TensorFlow 1.9 (tested)
Pip package
```bash
$ pip install segmentation-models
```
Latest version
```bash
$ pip install git+https://github.com/qubvel/segmentation_models
```
The latest documentation is available on Read the Docs.
To see important changes between versions, look at CHANGELOG.md.
Citing
```bibtex
@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
}
```
The project is distributed under the MIT License.




