Optimization with torch
Introduction
Now that we have a model and data, it is time to train, validate and test the model by optimizing its parameters on our data.
Training a model is an iterative process; in each iteration (called an epoch) the model makes a guess about the output, computes the error of that guess (the loss), collects the derivatives of the error with respect to its parameters (as we saw in the previous section), and updates those parameters using gradient descent.
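Before looking at the full implementation, that cycle can be sketched in a few lines of plain PyTorch. This is only a minimal sketch: the toy layer, random data and learning rate below are placeholders chosen to make the snippet self-contained, not part of the model built in the next section.

import torch
from torch import nn

# Placeholder data and a one-layer model, just to illustrate the update cycle.
x = torch.randn(8, 4)            # batch of 8 samples with 4 features
y = torch.randn(8, 1)            # matching targets
model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

pred = model(x)                  # forward pass: the model's guess
loss = loss_fn(pred, y)          # measure the error of the guess
optimizer.zero_grad()            # clear gradients accumulated by previous steps
loss.backward()                  # backpropagation: d(loss)/d(parameter)
optimizer.step()                 # gradient descent: p <- p - lr * grad

The full training and validation loops below repeat exactly these five lines over every batch of the dataset.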
Complete implementation
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# data
training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor(),
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)

# hyperparameters
learning_rate = 1e-3
batch_size = 64
epochs = 5

# training loop
# loss function
loss_fn = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# training function
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)  # move the batch to the same device as the model
        # compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # backpropagation
        optimizer.zero_grad()  # gradients accumulate by default, so reset them first
        loss.backward()        # compute and accumulate the gradients
        optimizer.step()       # optimization step: p -= lr * grad
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')

# validation function
def test_loop(testloader, model, loss_fn):
    size = len(testloader.dataset)
    num_batches = len(testloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# training cycle
epochs = 10
for t in range(epochs):
    print(f'Epoch {t+1}\n -------------------------')
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print('Done!')
Epoch 1
-------------------------
loss: 2.304489 [ 0/60000]
loss: 2.301195 [ 6400/60000]
loss: 2.293884 [12800/60000]
loss: 2.286755 [19200/60000]
loss: 2.272739 [25600/60000]
loss: 2.251779 [32000/60000]
loss: 2.252163 [38400/60000]
loss: 2.235502 [44800/60000]
loss: 2.224388 [51200/60000]
loss: 2.219147 [57600/60000]
Test Error:
Accuracy: 45.5%, Avg loss: 2.214530
Epoch 2
-------------------------
loss: 2.210820 [ 0/60000]
loss: 2.207516 [ 6400/60000]
loss: 2.182261 [12800/60000]
loss: 2.203938 [19200/60000]
loss: 2.155809 [25600/60000]
loss: 2.100431 [32000/60000]
loss: 2.140475 [38400/60000]
loss: 2.089364 [44800/60000]
loss: 2.085941 [51200/60000]
loss: 2.058784 [57600/60000]
Test Error:
Accuracy: 51.8%, Avg loss: 2.058807
Epoch 3
-------------------------
loss: 2.048604 [ 0/60000]
loss: 2.039883 [ 6400/60000]
loss: 1.980494 [12800/60000]
loss: 2.051986 [19200/60000]
loss: 1.935781 [25600/60000]
loss: 1.855610 [32000/60000]
loss: 1.935602 [38400/60000]
loss: 1.840926 [44800/60000]
loss: 1.859519 [51200/60000]
loss: 1.801622 [57600/60000]
Test Error:
Accuracy: 53.6%, Avg loss: 1.806677
Epoch 4
-------------------------
loss: 1.787565 [ 0/60000]
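The training and validation loops do not depend on the particular optimizer. As a hedged variation (not part of the run above), torch.optim.SGD could be swapped for torch.optim.Adam, reusing the model, loss function, dataloaders and loops already defined:

# Hypothetical variation: same model and loops, a different optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(epochs):
    print(f'Epoch {t+1}\n -------------------------')
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print('Done!')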