Friday, March 12, 2021

How to access class object when I use torch.nn.DataParallel()?

I want to train my model using PyTorch with multiple GPUs. I included the following line:

model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)  

Then I tried to access the optimizer defined in my model class:

G_opt = model.module.optimizer_G  

However, I got an error:

AttributeError: 'DataParallel' object has no attribute 'optimizer_G'

I think it is related to how the optimizer is defined in my model. It works on a single GPU without torch.nn.DataParallel, but with multiple GPUs it fails even though I access the optimizer through module, and I could not find a solution.

Here is the model definition:

class MyModel(torch.nn.Module):
    ...
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
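A minimal runnable sketch of why `.module` is needed at all: `torch.nn.DataParallel` does not forward attribute lookups to the model it wraps, so the wrapped model (and anything stored on it, such as an optimizer) is only reachable as `.module`. The `MyModel` below, its `netG` layer, and the default `lr`/`beta1` values are hypothetical stand-ins for the Pix2PixHD model and its `opt` options, not the actual implementation.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical stand-in: lr and beta1 replace opt.lr and opt.beta1."""

    def __init__(self, lr=0.0002, beta1=0.5):
        super().__init__()
        self.netG = nn.Linear(4, 4)  # hypothetical placeholder generator
        # The optimizer is a plain attribute, not a submodule, so it lives in
        # this instance's __dict__ and is invisible to a wrapping module.
        self.optimizer_G = torch.optim.Adam(
            self.netG.parameters(), lr=lr, betas=(beta1, 0.999)
        )

    def forward(self, x):
        return self.netG(x)


# device_ids omitted: uses all visible GPUs, or simply passes calls
# straight through to the wrapped model when no GPU is available.
model = nn.DataParallel(MyModel())

print(hasattr(model, "optimizer_G"))         # False: the wrapper does not forward it
print(hasattr(model.module, "optimizer_G"))  # True: the original model is .module
G_opt = model.module.optimizer_G
```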

I used the Pix2PixHD implementation on GitHub, if you want to see the full code.

Thank you. Best regards.

Edit: I solved the problem by using model.module.module.optimizer_G.
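That fix implies the model ended up wrapped in DataParallel twice: the error says model.module is itself a 'DataParallel' object, so one more .module is needed to reach MyModel. A plausible cause (an assumption, not confirmed in the question) is that the model-creation code already applies DataParallel before the line shown above wraps it again. The sketch below reproduces the error with a hypothetical minimal model and adds a hypothetical helper, `unwrap`, that strips any number of wrapper layers instead of hard-coding the nesting depth.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical minimal model that keeps its optimizer as an attribute."""

    def __init__(self):
        super().__init__()
        self.netG = nn.Linear(4, 4)
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.0002)


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: strip any number of (Distributed)DataParallel layers."""
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    while isinstance(model, wrappers):
        model = model.module
    return model


inner = MyModel()
# Wrapping twice (e.g. once in a factory function and again in the training
# script) nests one DataParallel inside another.
model = nn.DataParallel(nn.DataParallel(inner))

try:
    model.module.optimizer_G  # model.module is still a DataParallel wrapper
except AttributeError as err:
    print(err)  # 'DataParallel' object has no attribute 'optimizer_G'

G_opt = model.module.module.optimizer_G  # works, but only for exactly two layers
assert unwrap(model).optimizer_G is G_opt  # works for any nesting depth
```

The alternative fix is to remove the redundant DataParallel call so that a single .module is enough again; the helper is mainly useful when the same training code must run both with and without the wrapper.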

https://stackoverflow.com/questions/66607905/how-to-access-class-object-when-i-use-torch-nn-dataparallel March 13, 2021 at 06:13AM
