LightningDataModule

En klasse i PyTorch Lightning, der organiserer dataindlæsning, præprocessing og opdeling i trænings-, validerings- og testdatasæt på en genanvendelig måde.

Kort fortalt

LightningDataModule er en standardiseret måde at håndtere data på i PyTorch Lightning, så du kan genbruge din datakode på tværs af projekter og nemt skifte mellem forskellige datasæt.

Kategori
værktøj
Niveau
øvet
Udtale
/ˈlaɪtnɪŋˈdeɪtəˌmɒdjuːl/

Betydninger

1
  1. 1

    En basisklasse i PyTorch Lightning, der indkapsler al datarelateret funktionalitet: download, forbehandling, opdeling i trænings-, validerings- og testsæt samt oprettelse af dataloadere.

    • Ved at arve fra LightningDataModule kan du nemt skifte mellem forskellige datasæt uden at ændre træningsloopet.PyTorch Lightning dokumentation
    • En LightningDataModule håndterer typisk både datadownload og datatransformationer.

Hvornår bruges det

Bruges i PyTorch Lightning-projekter til at adskille datarelateret logik fra model- og træningslogik. Du opretter en underklasse af LightningDataModule og implementerer metoder som prepare_data, setup, train_dataloader, val_dataloader og test_dataloader. Det gør koden mere modulær og testbar.

Kodeeksempel

import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

Eksempel på en LightningDataModule til MNIST-datasættet. Klassen håndterer download, opdeling og dataloadere.

Oprindelse

Sammensat af 'Lightning' (fra PyTorch Lightning), 'Data' (data) og 'Module' (modul), der angiver en datamodul-komponent i Lightning-rammen.

Afledte ord

2

Kilder

1