LightningModule
LightningModule er en basisklasse i PyTorch Lightning, der definerer en grænseflade til organisering af deep learning-modeller, herunder forward-pass, træningstrin og konfiguration.
Kort fortalt
En LightningModule er en baseklasse i PyTorch Lightning der samler model, optimering og træningslogik i én struktureret enhed.
- Kategori
- arkitektur
- Niveau
- øvet
- Udtale
- ˈlaɪtnɪŋ ˈmɒdjuːl
Betydninger
1- 1
En basisklasse i PyTorch Lightning, der organiserer deep learning-modeller ved at samle netværksarkitektur, træningslogik og optimeringskonfiguration i én arvelig klasse.
- Ved at arve fra LightningModule kan man definere både forward-pass og træningstrin i samme klasse. — PyTorch Lightning dokumentation
- LightningModule håndterer automatisk flytning af modellen til den rigtige enhed under træning. — PyTorch Lightning dokumentation
Hvornår bruges det
Bruges når man vil skrive vedligeholdelig deep learning-kode. Man arver fra LightningModule og implementerer metoder som training_step, configure_optimizers og forward. Gør koden agnostisk over for hardware (GPU/TPU) og skalerbar.
Kodeeksempel
import lightning as L
import torch.nn as nn
import torch.optim as optim
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
def training_step(self, batch, batch_idx):
x, y = batch
pred = self(x)
loss = nn.functional.mse_loss(pred, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=1e-3)Eksempel på en LightningModule med en lineær model, træningstrin og optimeringskonfiguration.
Oprindelse
Navnet stammer fra PyTorch Lightning, et framework for deep learning, og 'Module' fra PyTorchs nn.Module. Introduceret af William Falcon i 2019.