Subclassing: A Multinomial Logistic Regression Model#

Introducción#

This lesson introduces object-oriented programming applied to Keras, known generically as subclassing.

The topic is recommended for users who already have some experience with programming and, in our case, with Keras.

We will build the multi-class classification model used in the introduction to the functional API. The example is once again Iris, so that the two programming styles can be compared.

Import modules#

try:
  %tensorflow_version 2.x
except Exception:
  pass
from __future__ import absolute_import, division, print_function, unicode_literals
#
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
#
from tensorflow.keras.models import Model
#
from tensorflow.keras.layers import Dense, Input, Activation
#
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.utils import plot_model
#
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
#
print(tf.__version__)
2.9.1

The Iris dataset#

This dataset was introduced by Sir Ronald Fisher in 1936.

Reading the data#

# column names for the data
col_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
target_dimensions = ['Setosa', 'Versicolor', 'Virginica']

# read the data
training_data_path = tf.keras.utils.get_file("iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_data_path = tf.keras.utils.get_file("iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

training = pd.read_csv(training_data_path, names=col_names, header=0)
test = pd.read_csv(test_data_path, names=col_names, header=0)

Pre-processing#

The target variable has three categories. We will use one-hot encoding.

One-hot encoding#

y_train = pd.DataFrame(to_categorical(training.Species))
y_train.columns = target_dimensions

y_test = pd.DataFrame(to_categorical(test.Species))
y_test.columns = target_dimensions
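Independently of Keras, the effect of `to_categorical` can be sketched with plain NumPy: class `i` maps to row `i` of an identity matrix. The labels below are hypothetical, chosen only to illustrate the encoding.

```python
import numpy as np

# Hypothetical integer labels: 0 = Setosa, 1 = Versicolor, 2 = Virginica
labels = np.array([0, 2, 1, 0])

# One-hot encoding: row i of the 3x3 identity matrix encodes class i
one_hot = np.eye(3)[labels]
print(one_hot)
```

Each row has exactly one 1, in the column of its class, which is what the `y_train` and `y_test` dataframes above contain.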

Drop the Species column#

training.drop(['Species'], axis=1, inplace=True)
#test.drop(['Species'], axis=1, inplace=True)
y_test_species = test.pop('Species') # extracts the column into y_test_species
#
# If you need to add the recoded labels back to the dataframe, use these lines
#training = training.join(labels_training)
#test = test.join(labels_test)

Normalize the features#

StandardScaler#

# create the StandardScaler object
scaler = StandardScaler()

# fit the scaler's parameters
scaler.fit(training)
print (scaler.mean_)

# scale training and test
x_train = scaler.transform(training)
x_test = scaler.transform(test)

# labels (do not require scaling)
[5.845      3.065      3.73916667 1.19666667]
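What `StandardScaler` does can be sketched directly: each column is shifted by its mean and divided by its standard deviation. The toy matrix below stands in for the Iris features; the values are hypothetical.

```python
import numpy as np

# Toy feature matrix standing in for the Iris features (hypothetical values)
X = np.array([[5.1, 3.5],
              [4.9, 3.0],
              [6.2, 3.4]])

# StandardScaler computes z = (x - mean) / std, column by column
mean = X.mean(axis=0)
std = X.std(axis=0)           # sklearn uses the population std (ddof=0)
X_scaled = (X - mean) / std

print(X_scaled.mean(axis=0))  # approximately 0 per column
print(X_scaled.std(axis=0))   # approximately 1 per column
```

After fitting, `scaler.transform` applies the same (mean, std) learned on `training` to both `training` and `test`, which is why the scaler is fitted only once.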

Create the model using subclassing#

We will derive a class from the Model class.

class CustomModel(Model):
    
    def __init__(self, **kwargs):
        super(CustomModel, self).__init__(**kwargs)
        self.dense1 = Dense(5, activation='relu')
        self.dense2 = Dense(10, activation='relu')
        self.dense3 = Dense(3, activation='softmax')
        
    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.dense3(x)
    
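The `call` method above chains the three layers: each `Dense` layer computes an affine map followed by its activation. A NumPy sketch of that forward pass, with hypothetical random weights matching the layer sizes 4 → 5 → 10 → 3, looks like this:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(0.0, x)

def softmax(x):
    # subtract the row max for numerical stability
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def dense(x, W, b):
    # a Dense layer computes x @ W + b, then applies its activation
    return x @ W + b

# Hypothetical weights for the layer sizes used above: 4 -> 5 -> 10 -> 3
W1, b1 = rng.normal(size=(4, 5)), np.zeros(5)
W2, b2 = rng.normal(size=(5, 10)), np.zeros(10)
W3, b3 = rng.normal(size=(10, 3)), np.zeros(3)

x = rng.normal(size=(2, 4))   # two scaled input rows
h1 = relu(dense(x, W1, b1))
h2 = relu(dense(h1, W2, b2))
out = softmax(dense(h2, W3, b3))
print(out.shape)              # (2, 3): one probability row per sample
```

The softmax rows are non-negative and sum to 1, which is what makes the output interpretable as class probabilities for the three species.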

Create an instance of CustomModel#

model_iris = CustomModel(name='my_custom_model')

Compile#

model_iris.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Train#

history = model_iris.fit(x_train, y_train.values,
                    batch_size= 64,
                    epochs= 30,
                    validation_split=0.2)
Epoch 1/30
2/2 [==============================] - 2s 512ms/step - loss: 1.0416 - accuracy: 0.2396 - val_loss: 1.0060 - val_accuracy: 0.3333
Epoch 2/30
2/2 [==============================] - 0s 40ms/step - loss: 1.0038 - accuracy: 0.2812 - val_loss: 0.9706 - val_accuracy: 0.3750
Epoch 3/30
2/2 [==============================] - 0s 47ms/step - loss: 0.9763 - accuracy: 0.3958 - val_loss: 0.9398 - val_accuracy: 0.7917
Epoch 4/30
2/2 [==============================] - 0s 47ms/step - loss: 0.9539 - accuracy: 0.5729 - val_loss: 0.9155 - val_accuracy: 0.7917
Epoch 5/30
2/2 [==============================] - 0s 47ms/step - loss: 0.9361 - accuracy: 0.5938 - val_loss: 0.8963 - val_accuracy: 0.7917
Epoch 6/30
2/2 [==============================] - 0s 48ms/step - loss: 0.9204 - accuracy: 0.6042 - val_loss: 0.8786 - val_accuracy: 0.7917
Epoch 7/30
2/2 [==============================] - 0s 56ms/step - loss: 0.9054 - accuracy: 0.6458 - val_loss: 0.8585 - val_accuracy: 0.7917
Epoch 8/30
2/2 [==============================] - 0s 56ms/step - loss: 0.8904 - accuracy: 0.6562 - val_loss: 0.8408 - val_accuracy: 0.7917
Epoch 9/30
2/2 [==============================] - 0s 56ms/step - loss: 0.8772 - accuracy: 0.6979 - val_loss: 0.8265 - val_accuracy: 0.7917
Epoch 10/30
2/2 [==============================] - 0s 56ms/step - loss: 0.8650 - accuracy: 0.6875 - val_loss: 0.8112 - val_accuracy: 0.7917
Epoch 11/30
2/2 [==============================] - 0s 48ms/step - loss: 0.8534 - accuracy: 0.6875 - val_loss: 0.7983 - val_accuracy: 0.7917
Epoch 12/30
2/2 [==============================] - 0s 63ms/step - loss: 0.8424 - accuracy: 0.6875 - val_loss: 0.7858 - val_accuracy: 0.7917
Epoch 13/30
2/2 [==============================] - 0s 64ms/step - loss: 0.8312 - accuracy: 0.6979 - val_loss: 0.7721 - val_accuracy: 0.7917
Epoch 14/30
2/2 [==============================] - 0s 56ms/step - loss: 0.8199 - accuracy: 0.7083 - val_loss: 0.7573 - val_accuracy: 0.7917
Epoch 15/30
2/2 [==============================] - 0s 48ms/step - loss: 0.8088 - accuracy: 0.7083 - val_loss: 0.7458 - val_accuracy: 0.7917
Epoch 16/30
2/2 [==============================] - 0s 48ms/step - loss: 0.7987 - accuracy: 0.7500 - val_loss: 0.7341 - val_accuracy: 0.7917
Epoch 17/30
2/2 [==============================] - 0s 40ms/step - loss: 0.7889 - accuracy: 0.7396 - val_loss: 0.7230 - val_accuracy: 0.7500
Epoch 18/30
2/2 [==============================] - 0s 56ms/step - loss: 0.7797 - accuracy: 0.7500 - val_loss: 0.7133 - val_accuracy: 0.7500
Epoch 19/30
2/2 [==============================] - 0s 48ms/step - loss: 0.7703 - accuracy: 0.7500 - val_loss: 0.7022 - val_accuracy: 0.7500
Epoch 20/30
2/2 [==============================] - 0s 56ms/step - loss: 0.7611 - accuracy: 0.7500 - val_loss: 0.6909 - val_accuracy: 0.7500
Epoch 21/30
2/2 [==============================] - 0s 56ms/step - loss: 0.7526 - accuracy: 0.7500 - val_loss: 0.6816 - val_accuracy: 0.7500
Epoch 22/30
2/2 [==============================] - 0s 47ms/step - loss: 0.7440 - accuracy: 0.7500 - val_loss: 0.6706 - val_accuracy: 0.7500
Epoch 23/30
2/2 [==============================] - 0s 63ms/step - loss: 0.7352 - accuracy: 0.7500 - val_loss: 0.6614 - val_accuracy: 0.7500
Epoch 24/30
2/2 [==============================] - 0s 56ms/step - loss: 0.7268 - accuracy: 0.7500 - val_loss: 0.6520 - val_accuracy: 0.7500
Epoch 25/30
2/2 [==============================] - 0s 56ms/step - loss: 0.7187 - accuracy: 0.7500 - val_loss: 0.6424 - val_accuracy: 0.7500
Epoch 26/30
2/2 [==============================] - 0s 64ms/step - loss: 0.7106 - accuracy: 0.7500 - val_loss: 0.6338 - val_accuracy: 0.7500
Epoch 27/30
2/2 [==============================] - 0s 64ms/step - loss: 0.7029 - accuracy: 0.7500 - val_loss: 0.6249 - val_accuracy: 0.7500
Epoch 28/30
2/2 [==============================] - 0s 48ms/step - loss: 0.6953 - accuracy: 0.7500 - val_loss: 0.6168 - val_accuracy: 0.7500
Epoch 29/30
2/2 [==============================] - 0s 64ms/step - loss: 0.6882 - accuracy: 0.7500 - val_loss: 0.6098 - val_accuracy: 0.7500
Epoch 30/30
2/2 [==============================] - 0s 72ms/step - loss: 0.6811 - accuracy: 0.7500 - val_loss: 0.6023 - val_accuracy: 0.7500
model_iris.summary()
#plot_model(model_iris, to_file='../Imagenes/iris_model.png',
#           show_shapes=True)
Model: "my_custom_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               multiple                  25        
                                                                 
 dense_1 (Dense)             multiple                  60        
                                                                 
 dense_2 (Dense)             multiple                  33        
                                                                 
=================================================================
Total params: 118
Trainable params: 118
Non-trainable params: 0
_________________________________________________________________

Model evaluation#

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.show()
plot_metric(history, 'loss')
(figure: training and validation loss curves)
plot_metric(history, 'accuracy')
(figure: training and validation accuracy curves)
model_iris.evaluate(x = x_test,y = y_test.values)
1/1 [==============================] - 0s 62ms/step - loss: 0.7537 - accuracy: 0.6000
[0.7536590099334717, 0.6000000238418579]

Predictions#

# predict on the test set
y_pred = model_iris.predict(x_test)
y_pred_c = np.argmax(y_pred, axis=1)
1/1 [==============================] - 0s 272ms/step
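The `argmax` step decodes each row of softmax probabilities into a class index, which can then be mapped back to a species name. The probability rows below are hypothetical, chosen only to illustrate the decoding.

```python
import numpy as np

target_dimensions = ['Setosa', 'Versicolor', 'Virginica']

# Hypothetical softmax outputs for three test rows
y_pred = np.array([[0.8, 0.15, 0.05],
                   [0.1, 0.2,  0.7 ],
                   [0.3, 0.6,  0.1 ]])

# argmax along axis=1 picks the most probable class per row
y_pred_c = np.argmax(y_pred, axis=1)
names = [target_dimensions[i] for i in y_pred_c]
print(y_pred_c)   # [0 2 1]
print(names)      # ['Setosa', 'Virginica', 'Versicolor']
```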

Confusion matrix#

cm = confusion_matrix(y_test_species, y_pred_c)
print("Our accuracy is {}%".format(((cm[0][0] + cm[1][1]+ cm[2][2])/y_test_species.shape[0])*100))
Our accuracy is 60.0%
sns.heatmap(cm, annot=True)
plt.savefig('h.png')
(figure: confusion-matrix heatmap)
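The accuracy formula above sums the diagonal of the confusion matrix by hand; `np.trace` generalizes `cm[0][0] + cm[1][1] + cm[2][2]` to any number of classes. The matrix below is hypothetical, used only to illustrate the computation.

```python
import numpy as np

# Hypothetical 3x3 confusion matrix (rows: true class, columns: predicted)
cm = np.array([[8, 0, 0],
               [0, 5, 3],
               [0, 3, 11]])

# accuracy = correctly classified samples (the diagonal) over all samples
accuracy = np.trace(cm) / cm.sum()
print(accuracy)   # 0.8
```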