Modelo Lineal de Clasificación con JAX
Contents
Modelo Lineal de Clasificación con JAX#
Introducción#
Con su versión actualizada de Autograd, JAX puede diferenciar automáticamente el código nativo de Python y NumPy. Puede derivarse a través de un gran subconjunto de características de Python, incluidos bucles, condicionales, recursión y closures, e incluso puede tomar derivadas de derivadas de derivadas. Admite la diferenciación tanto en modo inverso como en modo directo, y los dos pueden componerse arbitrariamente en cualquier orden.
Lo nuevo es que JAX usa XLA para compilar y ejecutar su código NumPy en aceleradores, como GPU y TPU. La compilación ocurre de forma predeterminada, con las llamadas de la biblioteca compiladas y ejecutadas justo a tiempo. Pero JAX incluso le permite compilar justo a tiempo sus propias funciones de Python en núcleos optimizados para XLA utilizando una API de una función. La compilación y la diferenciación automática se pueden componer de forma arbitraria, por lo que puede expresar algoritmos sofisticados y obtener el máximo rendimiento sin tener que abandonar Python.
# !pip install --upgrade jax jaxlib
from __future__ import print_function
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
# La convención actual es: import numpy original as "onp"
import numpy as onp
import itertools
#import random
#import jax
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Input In [1], in <cell line: 4>()
1 # !pip install --upgrade jax jaxlib
3 from __future__ import print_function
----> 4 import jax.numpy as np
5 from jax import grad, jit, vmap
6 from jax import random
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\__init__.py:21, in <module>
18 del _os
20 # flake8: noqa: F401
---> 21 from .config import config
22 from .api import (
23 ad, # TODO(phawkins): update users to avoid this.
24 argnums_partial, # TODO(phawkins): update Haiku to not use this.
(...)
87 xla_computation,
88 )
89 from .experimental.maps import soft_pmap
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\config.py:19, in <module>
17 import threading
18 from typing import Optional
---> 19 from jax import lib
21 def bool_env(varname: str, default: bool) -> bool:
22 """Read an environment variable and interpret it as a boolean.
23
24 True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
(...)
30 Raises: ValueError if the environment variable is anything else.
31 """
File ~\AppData\Local\Programs\Python\Python39\lib\site-packages\jax\lib\__init__.py:23, in <module>
1 # Copyright 2018 Google LLC
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
15 # This module is largely a wrapper around `jaxlib` that performs version
16 # checking on import.
18 __all__ = [
19 'cuda_prng', 'cusolver', 'rocsolver', 'jaxlib', 'lapack',
20 'pocketfft', 'pytree', 'tpu_client', 'version', 'xla_client'
21 ]
---> 23 import jaxlib
25 # Must be kept in sync with the jaxlib version in build/test-requirements.txt
26 _minimum_jaxlib_version = (0, 1, 60)
ModuleNotFoundError: No module named 'jaxlib'
Función de predicción#
def sigmoid(x):
return 0.5*(np.tanh(x/2)+1)
# más estable que 1.0/(1+np.exp(-x))
# genera la probabilidad de que una etiqueta sea verdadera
def predict(W,b,inputs):
return sigmoid(np.dot(inputs,W)+b)
Función de pérdida. Entropía cruzada#
# función de pérdida: -log de verosimilitud de ejemplos de entrenamiento
def loss(W,b,x,y):
preds = predict(W,b,x)
label_probs = preds*y + (1-preds)*(1-y)
return -np.sum(np.log(label_probs))
# inicializar coeficientes
key, W_key, b_key = random.split(key,3)
W = random.normal(key, (3,))
b = random.normal(key,())
Ejemplo. Datos de Juguete#
# Creando un dataset de juguete
inputs = np.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])
Gradiente#
Usaremos la función grad con sus argumentos para diferenciar la función con respecto a sus parámetros posicionales.
# compilar con jit
# argnums define parámetros posicionales para derivar con respecto a
grad_loss = jit(grad(loss,argnums=(0,1)))
W_grad, b_grad = grad_loss(W,b,inputs, targets)
print("W_grad = ", W_grad)
print("b_grad = ", b_grad)
W_grad = [ 0.15979266 0.15962079 -1.4914058 ]
b_grad = 0.42253572
Entrenamiento del modelo#
# función de entrenamiento
def train(W,b,x,y, lr= 0.12):
gradient = grad_loss(W,b,inputs,targets)
W_grad, b_grad = grad_loss(W,b,inputs,targets)
W -= W_grad*lr
b -= b_grad*lr
return(W,b)
# entrenamiento
weights, biases = [], []
train_loss= []
epochs = 20
train_loss.append(loss(W,b,inputs,targets))
for epoch in range(epochs):
W,b = train(W,b,inputs, targets)
weights.append(W)
biases.append(b)
losss = loss(W,b,inputs,targets)
train_loss.append(losss)
print(f"Epoch {epoch}: train loss {losss}")
Epoch 0: train loss 2.2908685207366943
Epoch 1: train loss 2.0348708629608154
Epoch 2: train loss 1.8085304498672485
Epoch 3: train loss 1.6108163595199585
Epoch 4: train loss 1.4400672912597656
Epoch 5: train loss 1.2939282655715942
Epoch 6: train loss 1.1695582866668701
Epoch 7: train loss 1.0639365911483765
Epoch 8: train loss 0.9741388559341431
Epoch 9: train loss 0.8975158333778381
Epoch 10: train loss 0.8317785263061523
Epoch 11: train loss 0.7750089764595032
Epoch 12: train loss 0.7256337404251099
Epoch 13: train loss 0.6823759078979492
Epoch 14: train loss 0.6442046165466309
Epoch 15: train loss 0.6102899312973022
Epoch 16: train loss 0.5799612998962402
Epoch 17: train loss 0.5526753664016724
Epoch 18: train loss 0.5279892683029175
Epoch 19: train loss 0.5055401921272278
print('weights')
for weight in weights:
print(weight)
print('biases')
for bias in biases:
print(bias)
weights
[ 0.94362557 -0.27246025 -0.08247474]
[ 0.92442816 -0.2868005 0.08655086]
[ 0.90611976 -0.2978541 0.24569045]
[ 0.8893553 -0.30661893 0.39487803]
[ 0.8745658 -0.31376946 0.5341197 ]
[ 0.86198187 -0.31977373 0.66359735]
[ 0.8516655 -0.32496306 0.783694 ]
[ 0.84355026 -0.3295752 0.8949632 ]
[ 0.8374844 -0.33378202 0.9980703 ]
[ 0.8332691 -0.3377079 1.0937309]
[ 0.83068764 -0.34144276 1.1826608 ]
[ 0.8295252 -0.34505108 1.2655417 ]
[ 0.82958055 -0.34857863 1.3430017 ]
[ 0.8306719 -0.35205755 1.4156077 ]
[ 0.8326388 -0.35550994 1.4838635 ]
[ 0.83534175 -0.35895056 1.5482135 ]
[ 0.83866084 -0.36238888 1.6090478 ]
[ 0.8424935 -0.36583066 1.6667081 ]
[ 0.84675235 -0.36927888 1.721494 ]
[ 0.85136336 -0.3727346 1.7736678 ]
biases
1.0681342
1.018281
0.97071755
0.9265029
0.8863388
0.8505914
0.81933606
0.79242367
0.7695556
0.75035113
0.73439974
0.72129667
0.7106637
0.70215917
0.6954815
0.6903681
0.6865927
0.6839612
0.68230784
0.68149114
print(grad(loss)(W,b,inputs,targets))
[-0.0408358 0.02885852 -0.4149384 ]
Calculando el valor de la función y el gradiente con value_and_grad#
from jax import value_and_grad
loss_val, Wb_grad = value_and_grad(loss,(0,1))(W,b,inputs, targets)
print('loss value: ', loss_val)
print('gradient value: ', Wb_grad)
loss value: 0.5055402
gradient value: (DeviceArray([-0.0408358 , 0.02885852, -0.4149384 ], dtype=float32), DeviceArray(0.00084008, dtype=float32))