Author: Humbulani Ndou
Date created: 2025/03/17
Last modified: 2025/03/17
Description: Gated Residual and Variable Selection Networks prediction with HyperParameters tuning.
The following example extends the script structured_data/classification_with_grn_and_vsn.py
by incorporating hyperparameter tuning using AutoKeras and KerasTuner. Specifics regarding which APIs are used from these two packages are described in detail in the relevant code sections.
This example demonstrates the use of Gated Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by Bryan Lim et al. in Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting, for structured data classification. GRNs give the model the flexibility to apply non-linear processing only where needed. VSNs allow the model to softly remove any unnecessary noisy inputs which could negatively impact performance. Together, these techniques help improve the learning capacity of deep neural network models.
Note that this example implements only the GRN and VSN components described in the paper, rather than the whole TFT model, as GRN and VSN can be useful on their own for structured data learning tasks.
To run the code you need to use TensorFlow 2.3 or higher.
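If you want to verify your environment before running the example, a minimal check such as the following may help. It only prints the installed versions of the packages used below; the exact versions shown on your machine are not prescribed here.

# Optional sanity check of the packages used in this example.
import tensorflow as tf
import keras
import keras_tuner
import autokeras

print("TensorFlow:", tf.__version__)
print("Keras:", keras.__version__)
print("KerasTuner:", keras_tuner.__version__)
print("AutoKeras:", autokeras.__version__)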
Our dataset is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a sample), and each column describes an attribute of the patient (a feature). We use the features to predict whether a patient has a heart disease (binary classification).
Here's the description of each feature:
Column | Description | Feature Type |
---|---|---|
Age | Age in years | Numerical |
Sex | (1 = male; 0 = female) | Categorical |
CP | Chest pain type (0, 1, 2, 3, 4) | Categorical |
Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical |
Chol | Serum cholesterol in mg/dl | Numerical |
FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical
RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical |
Thalach | Maximum heart rate achieved | Numerical |
Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical |
Oldpeak | ST depression induced by exercise relative to rest | Numerical |
Slope | Slope of the peak exercise ST segment | Numerical |
CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical |
Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical |
Target | Diagnosis of heart disease (1 = true; 0 = false) | Target |
import os
import subprocess
import tarfile

import numpy as np
import pandas as pd
import tree
from typing import Optional, Union

os.environ["KERAS_BACKEND"] = "tensorflow"  # or jax, or torch

# Keras imports
import keras
from keras import layers

# KerasTuner imports
import keras_tuner
from keras_tuner import HyperParameters

# AutoKeras imports
import autokeras as ak
from autokeras.utils import utils, types
Let's download the data and load it into a Pandas dataframe:
file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)
The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):
dataframe.shape
(303, 14)
Here's a preview of a few samples:
dataframe.head()
age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed | 0 |
1 | 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal | 1 |
2 | 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
3 | 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal | 0 |
4 | 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal | 0 |
The last column, "target", indicates whether the patient has a heart disease (1) or not (0).
Let's split the data into a training and validation set:
val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation
Here, we define the metadata of the dataset that will be useful for reading and parsing the data into input features, and encoding the input features with respect to their types.
COLUMN_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
    "target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted(
        [
            # Integer categorical values must stay int; string values must stay str.
            value if dataframe[feature_name].dtype == "int64" else str(value)
            for value in list(dataframe[feature_name].unique())
        ]
    )
    for feature_name in COLUMN_NAMES
    if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All feature names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)
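As a quick, optional sanity check (not part of the original pipeline), we can print the vocabularies that were just inferred:

# Optional: inspect the inferred vocabulary of each categorical feature.
for feature_name, vocabulary in CATEGORICAL_FEATURES_WITH_VOCABULARY.items():
    print(feature_name, vocabulary)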
The following features are categorical features encoded as integers:
sex
cp
fbs
restecg
exang
ca
We will encode these features using one-hot encoding. We have two options here: CategoryEncoding(), which requires knowing the range of input values and will error on input outside the range; or IntegerLookup(), which will build a lookup table for inputs and reserve an output index for unknown input values. For this example, we want a simple solution that will handle out-of-range inputs at inference, so we will use IntegerLookup().
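As a small illustration of this choice (synthetic values, not part of the pipeline), IntegerLookup simply maps unseen values to a reserved out-of-vocabulary index instead of raising an error:

# Illustrative only: with a vocabulary of [0, 1, 2, 3], index 0 is reserved by
# default for out-of-vocabulary values, so the unseen value 99 maps to 0.
demo_lookup = layers.IntegerLookup(vocabulary=[0, 1, 2, 3])
print(demo_lookup([2, 99]))  # -> [3, 0]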
We also have a categorical feature encoded as a string: thal. We will create an index of all possible feature values and encode the output using the StringLookup() layer.
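Similarly, here is a tiny standalone sketch (with a hypothetical vocabulary order) of how StringLookup indexes string values, with unseen strings falling into the out-of-vocabulary slot:

# Illustrative only: index 0 is the default out-of-vocabulary slot.
demo_string_lookup = layers.StringLookup(vocabulary=["fixed", "normal", "reversible"])
print(demo_string_lookup(["normal", "something-else"]))  # -> [2, 0]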
Finally, the following features are continuous numerical features:
age
trestbps
chol
thalach
oldpeak
slope
For each of these features, we will use a Normalization() layer to make sure the mean of each feature is 0 and its standard deviation is 1.
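As a small standalone sketch (with made-up values), adapt() learns the feature statistics and the layer then standardizes its inputs:

# Illustrative only: Normalization learns mean/variance via adapt() and then
# maps inputs to roughly zero mean and unit variance.
import numpy as np

demo_norm = layers.Normalization()
demo_values = np.array([[120.0], [140.0], [160.0]], dtype="float32")
demo_norm.adapt(demo_values)
print(demo_norm(demo_values))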
Below, we define a utility function, encode_categorical, to one-hot encode string or integer categorical features.

# Tensorflow is required for tf.data.Dataset
import tensorflow as tf


# We pre-process the categorical dataset elements here and convert them to one-hot
# encodings ahead of model training, since only the TensorFlow backend supports strings.
def encode_categorical(features, target):
    for f in features:
        if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out-of-vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            cls = (
                layers.StringLookup
                if features[f].dtype == "string"
                else layers.IntegerLookup
            )
            features[f] = cls(
                vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
                mask_token=None,
                num_oov_indices=0,
                output_mode="binary",
            )(features[f])

    # Change features from OrderedDict to Dict to match the Inputs, which are a Dict.
    return dict(features), target
Let's generate tf.data.Dataset objects for each dataframe:
def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = (
        tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
        .map(encode_categorical)
        .shuffle(buffer_size=len(dataframe))
    )
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)
Each Dataset yields a tuple (input, target), where input is a dictionary of features and target is the value 0 or 1:
for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=37>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=120>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=215>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 0, 0])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=170>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([1, 0, 0, 0])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 1, 0])>} Target: tf.Tensor(0, shape=(), dtype=int64)
Let's batch the datasets:
train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)
Here we subclass the Autokeras Graph and override its build method so that it can handle model Inputs passed as dictionaries. In structured data analysis, Inputs are normally passed as a dictionary with one entry per feature of interest.

class Graph(ak.graph.Graph):

    def build(self, hp):
        """Build the HyperModel into a Keras Model."""
        keras_nodes = {}
        keras_input_nodes = []
        for node in self.inputs:
            node_id = self._node_to_id[node]
            input_node = node.build_node(hp)
            output_node = node.build(hp, input_node)
            keras_input_nodes.append(input_node)
            keras_nodes[node_id] = output_node
        for block in self.blocks:
            temp_inputs = (
                {
                    n.name: keras_nodes[self._node_to_id[n]]
                    for n in block.inputs
                    if isinstance(n, ak.Input)
                }
                if isinstance(block.inputs[0], ak.Input)
                else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
            )
            outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
            for n, o in zip(block.outputs, outputs):
                keras_nodes[self._node_to_id[n]] = o
        model = keras.models.Model(
            keras_input_nodes,
            [
                keras_nodes[self._node_to_id[output_node]]
                for output_node in self.outputs
            ],
        )
        return self._compile_keras_model(hp, model)

    def _compile_keras_model(self, hp, model):
        # Specify hyperparameters from compile(...)
        optimizer_name = hp.Choice(
            "optimizer",
            ["adam", "sgd"],
            default="adam",
        )
        learning_rate = hp.Choice(
            "learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
        )

        if optimizer_name == "adam":
            optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        elif optimizer_name == "sgd":
            optimizer = keras.optimizers.SGD(learning_rate=learning_rate)

        model.compile(
            optimizer=optimizer,
            metrics=self._get_metrics(),
            loss=self._get_loss(),
        )

        return model
Input

Here we subclass the Autokeras Input node object and override the dtype attribute from None to a user-supplied value. We also override the build_node method to use the user-supplied name for the Input layers.
class Input(ak.Input):
    def __init__(self, dtype, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # Override dtype to a user-supplied dtype value
        self.dtype = dtype
        self.name = name

    def build_node(self, hp):
        return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)
Here we subclass the Autokeras ClassificationHead and override its __init__ method, and we add a get_expected_shape method to infer the labels' shape. We remove the preprocessing functionality, since we prefer to perform it manually.
class ClassifierHead(ak.ClassificationHead):

    def __init__(
        self,
        num_classes: Optional[int] = None,
        multi_label: bool = False,
        loss: Optional[types.LossType] = None,
        metrics: Optional[types.MetricsType] = None,
        dropout: Optional[float] = None,
        **kwargs,
    ):
        self.num_classes = num_classes
        self.multi_label = multi_label
        self.dropout = dropout
        if metrics is None:
            metrics = ["accuracy"]
        if loss is None:
            loss = self.infer_loss()
        ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
        self.shape = self.get_expected_shape()

    def get_expected_shape(self):
        # Compute expected shape from num_classes.
        if self.num_classes == 2 and not self.multi_label:
            return [1]
        return [self.num_classes]
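As a quick usage note (illustrative, not part of the pipeline), for binary, non-multi-label classification the head resolves to a single sigmoid output unit:

# Illustrative only: the expected label shape for binary classification is [1].
demo_head = ClassifierHead(num_classes=2, multi_label=False)
print(demo_head.shape)  # -> [1]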
This is a Keras layer defined in the script structured_data/classification_with_grn_and_vsn.py. More details about this layer may be found in the relevant script.
class GatedLinearUnit(layers.Layer):
    def __init__(self, num_units, activation, **kwargs):
        super().__init__(**kwargs)
        self.linear = layers.Dense(num_units)
        self.sigmoid = layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        return self.linear(inputs) * self.sigmoid(inputs)

    def build(self):
        self.built = True
This is a Keras layer defined in the script structured_data/classification_with_grn_and_vsn.py. More details about this layer may be found in the relevant script.
class GatedResidualNetwork(layers.Layer):

    def __init__(
        self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_units = num_units
        self.use_layernorm = use_layernorm
        self.elu_dense = layers.Dense(num_units, activation=activation)
        self.linear_dense = layers.Dense(num_units)
        self.dropout = layers.Dropout(dropout_rate)
        self.gated_linear_unit = GatedLinearUnit(num_units, activation)
        self.layer_norm = layers.LayerNormalization()
        self.project = layers.Dense(num_units)

    def call(self, inputs, hp):
        x = self.elu_dense(inputs)
        x = self.linear_dense(x)
        x = self.dropout(x)
        if inputs.shape[-1] != self.num_units:
            inputs = self.project(inputs)
        x = inputs + self.gated_linear_unit(x)
        use_layernorm = self.use_layernorm
        if use_layernorm is None:
            use_layernorm = hp.Boolean("use_layernorm", default=True)
        if use_layernorm:
            x = self.layer_norm(x)
        return x

    def build(self):
        self.built = True
VariableSelection Block

We have converted the following Keras layer into an Autokeras Block so that it exposes hyperparameters to tune. Refer to the Autokeras blocks API for writing custom Blocks.
class VariableSelection(ak.Block):
    def __init__(
        self,
        num_units: Optional[Union[int, HyperParameters.Choice]] = None,
        dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
        activation: Optional[Union[str, HyperParameters.Choice]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dropout = utils.get_hyperparameter(
            dropout_rate,
            HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
            float,
        )
        self.num_units = utils.get_hyperparameter(
            num_units,
            HyperParameters().Choice(
                "num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
            ),
            int,
        )
        self.activation = utils.get_hyperparameter(
            activation,
            HyperParameters().Choice(
                "vsn_activation", ["sigmoid", "elu"], default="sigmoid"
            ),
            str,
        )

    def build(self, hp, inputs):
        num_units = utils.add_to_hp(self.num_units, hp, "num_units")
        dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
        activation = utils.add_to_hp(self.activation, hp, "activation")
        concat_inputs = []
        # Project the features to 'num_units' dimension
        for input_ in inputs:
            if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(inputs[input_])
                )
            else:
                # Create a Normalization layer for our feature
                normalizer = layers.Normalization()
                # Prepare a Dataset that only yields our feature
                feature_ds = train_ds.map(lambda x, y: x[input_]).map(
                    lambda x: keras.ops.expand_dims(x, -1)
                )
                # Learn the statistics of the data
                normalizer.adapt(feature_ds)
                # Normalize the input feature
                normal_feature = normalizer(inputs[input_])
                concat_inputs.append(
                    keras.layers.Dense(units=num_units)(normal_feature)
                )
        v = layers.concatenate(concat_inputs)
        v = GatedResidualNetwork(
            num_units=num_units, dropout_rate=dropout_rate, activation=activation
        )(v, hp=hp)
        v = keras.ops.expand_dims(
            layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
        )
        x = []
        x += [
            GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
            for i in concat_inputs
        ]
        x = keras.ops.stack(x, axis=1)
        return keras.ops.squeeze(
            keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
        )
# Categorical features have different shapes after encoding, depending on the
# vocabulary or unique values of each feature. We create the Inputs accordingly to
# match the data elements generated by tf.data.Dataset after pre-processing them.
def create_model_inputs():
    inputs = {
        f: (
            Input(
                name=f,
                shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
                dtype="int64",
            )
            if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
            else Input(name=f, shape=(1,), dtype="float32")
        )
        for f in FEATURE_NAMES
    }
    return inputs
HyperModel

Here we use the Autokeras Functional API to construct a network of Blocks, which will be built into a KerasTuner HyperModel and finally into a Keras Model.
class MyHyperModel(keras_tuner.HyperModel):
    def build(self, hp):
        inputs = create_model_inputs()
        features = VariableSelection()(inputs)
        outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
        model = Graph(inputs=inputs, outputs=outputs)
        model = model.build(hp)
        return model

    def fit(self, hp, model, *args, **kwargs):
        return model.fit(
            *args,
            # Tune whether to shuffle the data in each epoch.
            shuffle=hp.Boolean("shuffle"),
            **kwargs,
        )
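As an optional sketch (not required for the tuner run below), the HyperModel can also be built eagerly with a fresh HyperParameters object, which is a convenient way to confirm that the graph wires up and to see which hyperparameters get registered during the build:

# Illustrative only: build one model outside the tuner and list the
# hyperparameters registered during the build.
demo_hp = keras_tuner.HyperParameters()
demo_model = MyHyperModel().build(demo_hp)
print(demo_hp.values)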
RandomSearch Tuner to find the best HyperParameters

We use the RandomSearch tuner to search for hyperparameters in the search space. We also display the search space summary.
print("Start training and searching for the best model...") tuner = keras_tuner.RandomSearch( MyHyperModel(), objective="val_accuracy", max_trials=3, overwrite=True, directory="my_dir", project_name="tune_hypermodel", ) # Show the search space summary print("Tuner search space summary:\n") tuner.search_space_summary() # Search for best model tuner.search(train_ds, epochs=2, validation_data=val_ds)
Trial 3 Complete [00h 00m 16s] val_accuracy: 0.8032786846160889
Best val_accuracy So Far: 0.8032786846160889 Total elapsed time: 00h 00m 34s
# Get the top model.
models = tuner.get_best_models(num_models=1)
best_model = models[0]
best_model.summary()
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/saving/saving_lib.py:757: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 346 variables. saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional"
βββββββββββββββββββββββ³ββββββββββββββββββββ³βββββββββββββ³ββββββββββββββββββββ β Layer (type) β Output Shape β Param # β Connected to β β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ© β age (InputLayer) β (None, 1) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β chol (InputLayer) β (None, 1) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β oldpeak β (None, 1) β 0 β - β β (InputLayer) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β slope (InputLayer) β (None, 1) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β thalach β (None, 1) β 0 β - β β (InputLayer) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β trestbps β (None, 1) β 0 β - β β (InputLayer) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32 β (None, 1) β 0 β age[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β ca (InputLayer) β (None, 4) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_2 β (None, 1) β 0 β chol[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cp (InputLayer) β (None, 5) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β exang (InputLayer) β (None, 2) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β fbs (InputLayer) β (None, 2) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_6 β (None, 1) β 0 β oldpeak[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β restecg β (None, 3) β 0 β - β β (InputLayer) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β sex (InputLayer) β (None, 2) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_9 β (None, 1) β 0 β slope[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β thal (InputLayer) β (None, 5) β 0 β - β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_11 β (None, 1) β 0 β thalach[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_12 β (None, 1) β 0 β trestbps[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization β (None, 1) β 3 β cast_to_float32[β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_1 β (None, 4) β 0 β ca[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization_1 β (None, 1) β 3 β cast_to_float32_β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_3 β (None, 5) β 0 β cp[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_4 β (None, 2) β 0 β exang[0][0] β β (CastToFloat32) β β β β 
βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_5 β (None, 2) β 0 β fbs[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization_2 β (None, 1) β 3 β cast_to_float32_β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_7 β (None, 3) β 0 β restecg[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_8 β (None, 2) β 0 β sex[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization_3 β (None, 1) β 3 β cast_to_float32_β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β cast_to_float32_10 β (None, 5) β 0 β thal[0][0] β β (CastToFloat32) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization_4 β (None, 1) β 3 β cast_to_float32_β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β normalization_5 β (None, 1) β 3 β cast_to_float32_β¦ β β (Normalization) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense (Dense) β (None, 16) β 32 β normalization[0]β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_1 (Dense) β (None, 16) β 80 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_2 (Dense) β (None, 16) β 32 β normalization_1[β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_3 (Dense) β (None, 16) β 96 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_4 (Dense) β (None, 16) β 48 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_5 (Dense) β (None, 16) β 48 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_6 (Dense) β (None, 16) β 32 β normalization_2[β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_7 (Dense) β (None, 16) β 64 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_8 (Dense) β (None, 16) β 48 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_9 (Dense) β (None, 16) β 32 β normalization_3[β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_10 (Dense) β (None, 16) β 96 β cast_to_float32_β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_11 (Dense) β (None, 16) β 32 β normalization_4[β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_12 (Dense) β (None, 16) β 32 β normalization_5[β¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β concatenate β (None, 208) β 0 β dense[0][0], β β (Concatenate) β β β dense_1[0][0], β β β β β dense_2[0][0], β β β β β dense_3[0][0], β β β β β dense_4[0][0], β β β β β dense_5[0][0], β β β β β dense_6[0][0], β β β β β dense_7[0][0], β β β β β dense_8[0][0], β β β β β dense_9[0][0], β β β β β dense_10[0][0], β β β β β dense_11[0][0], β β β β β dense_12[0][0] 
β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 7,536 β concatenate[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_18 (Dense) β (None, 13) β 221 β gated_residual_nβ¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β expand_dims β (None, 13, 1) β 0 β dense_18[0][0] β β (ExpandDims) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_1[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_2[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_3[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_4[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_5[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_6[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_7[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_8[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_9[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_10[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_11[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β gated_residual_netβ¦ β (None, 16) β 1,120 β dense_12[0][0] β β (GatedResidualNetwβ¦ β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β transpose β (None, 1, 13) β 0 β expand_dims[0][0] β β (Transpose) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β stack (Stack) β (None, 13, 16) β 0 β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β β β β β gated_residual_nβ¦ β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β matmul (Matmul) β (None, 1, 16) β 0 β transpose[0][0], β β β β β stack[0][0] β 
βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β squeeze (Squeeze) β (None, 16) β 0 β matmul[0][0] β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dropout_14 β (None, 16) β 0 β squeeze[0][0] β β (Dropout) β β β β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β dense_84 (Dense) β (None, 1) β 17 β dropout_14[0][0] β βββββββββββββββββββββββΌββββββββββββββββββββΌβββββββββββββΌββββββββββββββββββββ€ β classifier_head_1 β (None, 1) β 0 β dense_84[0][0] β β (Activation) β β β β βββββββββββββββββββββββ΄ββββββββββββββββββββ΄βββββββββββββ΄ββββββββββββββββββββ
Total params: 23,024 (89.96 KB)
Trainable params: 23,006 (89.87 KB)
Non-trainable params: 18 (96.00 B)
To get a prediction for a new sample, you can simply call model.predict(). There are just two things you need to do: wrap each scalar value in a list so as to have a batch dimension (models only process batches of data, not single samples), and call convert_to_tensor on each feature.

sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}


# Given the category (in the sample above - key) and the category value
# (in the sample above - value), we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
    # Create a list of zeros with the same length as categories
    encoding = [0] * len(cat)
    # Find the index of category_value in categories and set the corresponding position to 1
    if cat_value in cat:
        encoding[cat.index(cat_value)] = 1
    return encoding


for name, value in sample.items():
    if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
        sample.update(
            {
                name: get_cat_encoding(
                    CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
                )
            }
        )
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = best_model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having a heart disease, "
    "as evaluated by our model."
)
1/1 ββββββββββββββββββββ 0s 136ms/step
1/1 ββββββββββββββββββββ 0s 162ms/step
This particular patient had a 28.1 percent probability of having a heart disease, as evaluated by our model.
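Since the classifier head ends in a single sigmoid unit, the prediction is a probability; a minimal follow-up sketch (the 0.5 threshold is an assumption, not part of the original example) turns it into a hard label:

# Illustrative only: threshold the predicted probability at 0.5 to get a class label.
probability = float(predictions[0][0])
predicted_label = int(probability >= 0.5)
print(f"Predicted class: {predicted_label} (P(heart disease) = {probability:.3f})")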