target network

Et separat neuralt netværk, der bruges som et stabilt mål under Q-læring for at reducere svingninger i træningen.

Kort fortalt

Et kopi af hovednetværket, der opdateres sjældnere for at gøre træningen mere stabil.

Kategori
teknik
Niveau
øvet
Udtale
/ˈtɑːrɡɪt ˈnɛtwɜːrk/

Betydninger

1
  1. 1

    Et neuralt netværk, der holdes fast under træning af et hovednetværk, og som periodisk opdateres til at matche hovednetværkets parametre, for at give stabile målværdier i bootstrapping.

    • I DQN opdateres target-netværket hvert C antal steps ved at kopiere vægtene fra hovednetværket.Mnih et al., 2015

Hvornår bruges det

Target network bruges i deep Q-learning-algoritmer som DQN. Hovednetværket lærer at forudsige Q-værdier, mens target-netværket genererer de målværdier, som hovednetværket trænes mod. Target-netværkets vægte kopieres periodisk fra hovednetværket med faste intervaller eller via blød opdatering.

Formel

DQN mål-Q-værdi: y = r + γ * max_a' Q_target(s', a')

Kodeeksempel

class DQN:
    def __init__(self):
        self.q_network = QNetwork()
        self.target_network = QNetwork()
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.update_frequency = 100
        self.step = 0

    def update_target(self):
        if self.step % self.update_frequency == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

    def train_step(self, batch):
        # ... training code
        y = rewards + gamma * self.target_network(next_states).max(dim=1)[0] * (1 - dones)
        loss = F.mse_loss(self.q_network(states).gather(1, actions), y)
        self.step += 1
        self.update_target()

Eksempel på target network i en DQN-klasse: target-netværket opdateres hver 100. træningsstep ved at kopiere vægtene fra Q-netværket.

Oprindelse

Termen 'target network' blev introduceret i forbindelse med Deep Q-Network (DQN) af Mnih et al. (2015), hvor det blev brugt til at stabilisere træningen.

Afledte ord

2

Kilder

1