LightningDataModule
En klasse i PyTorch Lightning, der organiserer dataindlæsning, præprocessing og opdeling i trænings-, validerings- og testdatasæt på en genanvendelig måde.
Kort fortalt
LightningDataModule er en standardiseret måde at håndtere data på i PyTorch Lightning, så du kan genbruge din datakode på tværs af projekter og nemt skifte mellem forskellige datasæt.
- Kategori
- værktøj
- Niveau
- øvet
- Udtale
- /ˈlaɪtnɪŋˈdeɪtəˌmɒdjuːl/
Betydninger
1- 1
En basisklasse i PyTorch Lightning, der indkapsler al datarelateret funktionalitet: download, forbehandling, opdeling i trænings-, validerings- og testsæt samt oprettelse af dataloadere.
- Ved at arve fra LightningDataModule kan du nemt skifte mellem forskellige datasæt uden at ændre træningsloopet. — PyTorch Lightning dokumentation
- En LightningDataModule håndterer typisk både datadownload og datatransformationer.
Hvornår bruges det
Bruges i PyTorch Lightning-projekter til at adskille datarelateret logik fra model- og træningslogik. Du opretter en underklasse af LightningDataModule og implementerer metoder som prepare_data, setup, train_dataloader, val_dataloader og test_dataloader. Det gør koden mere modulær og testbar.
Kodeeksempel
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './', batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
def prepare_data(self):
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str):
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)Eksempel på en LightningDataModule til MNIST-datasættet. Klassen håndterer download, opdeling og dataloadere.
Oprindelse
Sammensat af 'Lightning' (fra PyTorch Lightning), 'Data' (data) og 'Module' (modul), der angiver en datamodul-komponent i Lightning-rammen.