Tuesday, January 26, 2021

Final step of PyTorch Gradient Accumulation for small datasets

I am training a BERT model on a relatively small dataset and cannot afford to lose any labelled samples, as they must all be used for training. Due to GPU memory constraints, I am using gradient accumulation to emulate larger batches (e.g. an effective batch size of 32). According to the PyTorch documentation, gradient accumulation is implemented as follows:

scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / iters_to_accumulate

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

However, with e.g. 110 training samples, a batch size of 8 and 4 accumulation steps (i.e. an effective batch size of 32), this method would only ever step the optimizer on the first 96 samples (i.e. 32 x 3); the gradients from the remaining 14 samples are accumulated but never applied. To avoid wasting them, I'd like to modify the code as follows (note the change to the if condition):

scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / iters_to_accumulate

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0 or (i + 1) == len(data):
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
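For concreteness, here is a quick sanity check of the batch arithmetic above, assuming the DataLoader is created with drop_last=False (so the 110 samples yield 14 batches). It simply lists which iterations would trigger an optimizer step under the original and the modified condition:

import math

num_samples = 110
batch_size = 8
iters_to_accumulate = 4

# With drop_last=False the loader yields 13 full batches plus one batch of 6.
num_batches = math.ceil(num_samples / batch_size)  # 14

original_steps = [i + 1 for i in range(num_batches)
                  if (i + 1) % iters_to_accumulate == 0]
modified_steps = [i + 1 for i in range(num_batches)
                  if (i + 1) % iters_to_accumulate == 0 or (i + 1) == num_batches]

print(original_steps)  # [4, 8, 12]     -> only 12 * 8 = 96 samples ever reach a step
print(modified_steps)  # [4, 8, 12, 14] -> the remaining 14 samples are flushed as well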

Is this correct and really that simple, or will this have any side effects? It seems very simple to me, but I've never seen it done before. Any help appreciated!
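For reference, a minimal self-contained version of the modified loop I have in mind looks like this. The tiny linear model and random tensors are just placeholders for BERT and my actual dataset, while the numbers (110 samples, batch size 8, accumulation over 4 iterations) match the example above:

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder data and model (stand-ins for the real dataset and BERT).
dataset = TensorDataset(torch.randn(110, 16), torch.randint(0, 2, (110,)))
data = DataLoader(dataset, batch_size=8, shuffle=True)  # drop_last=False by default

model = nn.Linear(16, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=(device == "cuda"))
iters_to_accumulate = 4

for epoch in range(2):
    for i, (input, target) in enumerate(data):
        input, target = input.to(device), target.to(device)
        with autocast(enabled=(device == "cuda")):
            output = model(input)
            loss = loss_fn(output, target) / iters_to_accumulate

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        # Step on every full accumulation window *and* on the final (possibly
        # partial) window, so no labelled sample is silently dropped.
        if (i + 1) % iters_to_accumulate == 0 or (i + 1) == len(data):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()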

Source: https://stackoverflow.com/questions/65842691/final-step-of-pytorch-gradient-accumulation-for-small-datasets (January 22, 2021 at 05:41 PM)
