Modelo Lineal de Clasificación con JAX#

Introducción#

Con su versión actualizada de Autograd, JAX puede diferenciar automáticamente el código nativo de Python y NumPy. Puede derivarse a través de un gran subconjunto de características de Python, incluidos bucles, condicionales, recursión y closures, e incluso puede tomar derivadas de derivadas de derivadas. Admite la diferenciación tanto en modo inverso como en modo directo, y los dos pueden componerse arbitrariamente en cualquier orden.

Lo nuevo es que JAX usa XLA para compilar y ejecutar su código NumPy en aceleradores, como GPU y TPU. La compilación ocurre de forma predeterminada, con las llamadas de la biblioteca compiladas y ejecutadas justo a tiempo. Pero JAX incluso le permite compilar justo a tiempo sus propias funciones de Python en núcleos optimizados para XLA utilizando una API de una función. La compilación y la diferenciación automática se pueden componer de forma arbitraria, por lo que puede expresar algoritmos sofisticados y obtener el máximo rendimiento sin tener que abandonar Python.

# !pip install --upgrade jax jaxlib 

from __future__ import print_function
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
# La convención actual es: import numpy original as "onp"
import numpy as onp
import itertools

#import random
#import jax
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Input In [1], in <cell line: 4>()
      1 # !pip install --upgrade jax jaxlib 
      3 from __future__ import print_function
----> 4 import jax.numpy as np
      5 from jax import grad, jit, vmap
      6 from jax import random

File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\__init__.py:21, in <module>
     18 del _os
     20 # flake8: noqa: F401
---> 21 from .config import config
     22 from .api import (
     23   ad,  # TODO(phawkins): update users to avoid this.
     24   argnums_partial,  # TODO(phawkins): update Haiku to not use this.
   (...)
     87   xla_computation,
     88 )
     89 from .experimental.maps import soft_pmap

File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\config.py:19, in <module>
     17 import threading
     18 from typing import Optional
---> 19 from jax import lib
     21 def bool_env(varname: str, default: bool) -> bool:
     22   """Read an environment variable and interpret it as a boolean.
     23 
     24   True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
   (...)
     30   Raises: ValueError if the environment variable is anything else.
     31   """

File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\lib\__init__.py:23, in <module>
      1 # Copyright 2018 Google LLC
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     15 # This module is largely a wrapper around `jaxlib` that performs version
     16 # checking on import.
     18 __all__ = [
     19   'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
     20   'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
     21 ]
---> 23 import jaxlib
     25 # Must be kept in sync with the jaxlib version in build/test-requirements.txt
     26 _minimum_jaxlib_version = (0, 1, 60)

ModuleNotFoundError: No module named 'jaxlib'

Función de predicción#

def sigmoid(x):
    return 0.5*(np.tanh(x/2)+1)
# más estable que  1.0/(1+np.exp(-x))

# genera la probabilidad de que una etiqueta sea verdadera
def predict(W,b,inputs):
    return sigmoid(np.dot(inputs,W)+b)

Función de pérdida. Entropía cruzada#

# función de pérdida: -log de verosimilitud de ejemplos de entrenamiento
def loss(W,b,x,y):
    preds = predict(W,b,x)
    label_probs = preds*y + (1-preds)*(1-y)
    return -np.sum(np.log(label_probs))

# inicializar coeficientes
key, W_key, b_key = random.split(key,3)
W = random.normal(key, (3,))
b = random.normal(key,())

Ejemplo. Datos de Juguete#

# Creando un dataset de juguete
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

Gradiente#

Usaremos la función grad con sus argumentos para diferenciar la función con respecto a sus parámetros posicionales.

# compilar con jit
# argnums define parámetros posicionales para derivar con respecto a
grad_loss = jit(grad(loss,argnums=(0,1)))
W_grad, b_grad = grad_loss(W,b,inputs, targets)
print("W_grad = ", W_grad)
print("b_grad = ", b_grad)
W_grad =  [ 0.15979266  0.15962079 -1.4914058 ]
b_grad =  0.42253572

Entrenamiento del modelo#

# función de entrenamiento
def train(W,b,x,y, lr= 0.12):
    gradient = grad_loss(W,b,inputs,targets) 
    W_grad, b_grad = grad_loss(W,b,inputs,targets)
    W -= W_grad*lr
    b -= b_grad*lr
    return(W,b)
# entrenamiento
weights, biases = [], []
train_loss= []
epochs = 20

train_loss.append(loss(W,b,inputs,targets))

for epoch in range(epochs):
    W,b = train(W,b,inputs, targets)
    weights.append(W)
    biases.append(b)
    losss = loss(W,b,inputs,targets)
    train_loss.append(losss)
    print(f"Epoch {epoch}: train loss {losss}")
Epoch 0: train loss 2.2908685207366943
Epoch 1: train loss 2.0348708629608154
Epoch 2: train loss 1.8085304498672485
Epoch 3: train loss 1.6108163595199585
Epoch 4: train loss 1.4400672912597656
Epoch 5: train loss 1.2939282655715942
Epoch 6: train loss 1.1695582866668701
Epoch 7: train loss 1.0639365911483765
Epoch 8: train loss 0.9741388559341431
Epoch 9: train loss 0.8975158333778381
Epoch 10: train loss 0.8317785263061523
Epoch 11: train loss 0.7750089764595032
Epoch 12: train loss 0.7256337404251099
Epoch 13: train loss 0.6823759078979492
Epoch 14: train loss 0.6442046165466309
Epoch 15: train loss 0.6102899312973022
Epoch 16: train loss 0.5799612998962402
Epoch 17: train loss 0.5526753664016724
Epoch 18: train loss 0.5279892683029175
Epoch 19: train loss 0.5055401921272278
print('weights')
for weight in weights:
    print(weight)
print('biases')
for bias in biases:
    print(bias)
weights
[ 0.94362557 -0.27246025 -0.08247474]
[ 0.92442816 -0.2868005   0.08655086]
[ 0.90611976 -0.2978541   0.24569045]
[ 0.8893553  -0.30661893  0.39487803]
[ 0.8745658  -0.31376946  0.5341197 ]
[ 0.86198187 -0.31977373  0.66359735]
[ 0.8516655  -0.32496306  0.783694  ]
[ 0.84355026 -0.3295752   0.8949632 ]
[ 0.8374844  -0.33378202  0.9980703 ]
[ 0.8332691 -0.3377079  1.0937309]
[ 0.83068764 -0.34144276  1.1826608 ]
[ 0.8295252  -0.34505108  1.2655417 ]
[ 0.82958055 -0.34857863  1.3430017 ]
[ 0.8306719  -0.35205755  1.4156077 ]
[ 0.8326388  -0.35550994  1.4838635 ]
[ 0.83534175 -0.35895056  1.5482135 ]
[ 0.83866084 -0.36238888  1.6090478 ]
[ 0.8424935  -0.36583066  1.6667081 ]
[ 0.84675235 -0.36927888  1.721494  ]
[ 0.85136336 -0.3727346   1.7736678 ]
biases
1.0681342
1.018281
0.97071755
0.9265029
0.8863388
0.8505914
0.81933606
0.79242367
0.7695556
0.75035113
0.73439974
0.72129667
0.7106637
0.70215917
0.6954815
0.6903681
0.6865927
0.6839612
0.68230784
0.68149114
print(grad(loss)(W,b,inputs,targets))
[-0.0408358   0.02885852 -0.4149384 ]

Calculando el valor de la función y el gradiente con value_and_grad#

from jax import value_and_grad
loss_val, Wb_grad = value_and_grad(loss,(0,1))(W,b,inputs, targets)
print('loss value: ', loss_val)
print('gradient value: ', Wb_grad)
loss value:  0.5055402
gradient value:  (DeviceArray([-0.0408358 ,  0.02885852, -0.4149384 ], dtype=float32), DeviceArray(0.00084008, dtype=float32))