Thursday, February 4, 2021

torch: minimally pad a tensor so that its number of elements is divisible by x

Suppose I have a tensor t with an arbitrary number of dimensions. I want to pad it with zeros such that:
a) I introduce the fewest possible new elements, and
b) after padding, t.numel() % x == 0.

Is there a better algorithm than finding the largest dimension and increasing it by 1 until condition (b) is satisfied?
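Even among single-dimension choices, the largest dimension is not always the cheapest one to grow. Growing dimension d (size s_d) by k adds k * P elements, where P is the product of the other dimensions, and the new count P * (s_d + k) is divisible by x exactly when s_d + k is a multiple of m = x / gcd(P, x). So the smallest k for that dimension is (-s_d) mod m, with no loop needed. For shape (2, 3) and x = 9, the largest-dimension heuristic grows dim 1 to 9 (12 elements added), whereas growing dim 0 to 3 gives (3, 3) with only 3 added. The sketch below works on plain shape tuples; the helper name best_single_dim_pad is hypothetical, and it still assumes only one dimension may grow.

```python
import math


def best_single_dim_pad(shape, x):
    """Hypothetical helper: return (dim, k, added) for the cheapest way to make
    prod(shape) divisible by x by growing a single dimension by k."""
    numel = math.prod(shape)
    if numel % x == 0:
        return (None, 0, 0)  # already divisible, nothing to pad
    best = None
    for d, size in enumerate(shape):
        rest = numel // size  # product of the other dimensions (all sizes > 0 here)
        m = x // math.gcd(rest, x)  # (size + k) must be a multiple of m
        k = (-size) % m  # smallest non-negative k reaching that multiple
        added = rest * k  # number of zero elements this choice introduces
        if best is None or added < best[2]:
            best = (d, k, added)
    return best
```

For example, best_single_dim_pad((2, 3), 9) returns (0, 1, 3): grow dim 0 by 1 to reach shape (3, 3).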

Code that seems to work:

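```python
import numpy as np
import torch


def pad_minimally(t, x):
    # Heuristic: grow only the largest dimension, one slice at a time.
    # Not always minimal, e.g. shape (2, 3) with x = 9.
    largest_dim = int(np.argmax(t.shape))
    buffer_shape = list(t.shape)
    new_t = t.clone()
    for n_to_add in range(x):
        buffer_shape[largest_dim] = n_to_add
        # Zero buffer with the same dtype/device as t so torch.cat accepts it
        new_buffer = torch.zeros(*buffer_shape, dtype=t.dtype, device=t.device)
        new_t = torch.cat([t, new_buffer], dim=largest_dim)
        if new_t.numel() % x == 0:
            break
    assert new_t.numel() % x == 0
    return new_t


assert pad_minimally(torch.rand(3, 1), 7).shape == (7, 1)
assert pad_minimally(torch.rand(3, 2), 7).shape == (7, 2)
assert pad_minimally(torch.rand(3, 2, 6), 7).shape == (3, 2, 7)
```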
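Growing a single dimension is not always minimal either. For shape (2, 2) and x = 9, the loop above grows dim 0 to 9 (shape (9, 2), 14 elements added), while growing both dimensions to (3, 3) adds only 5. An exact answer can be found by exhaustive search: an increment of x or more on any dimension can be lowered by x without changing the product modulo x, so the result stays divisible and gets smaller, which means each increment only needs to range over 0..x-1. This is a minimal sketch assuming ndim >= 1 and small x and ndim, since it tries x ** ndim candidates; the helper names min_pad_shape and f_pad_arg are hypothetical.

```python
import itertools
import math


def min_pad_shape(shape, x):
    """Hypothetical helper: smallest shape >= `shape` (elementwise) whose
    element count is divisible by x, found by trying every increment in range(x)."""
    best_shape, best_numel = tuple(shape), None
    for incs in itertools.product(range(x), repeat=len(shape)):
        new_shape = tuple(s + i for s, i in zip(shape, incs))
        n = math.prod(new_shape)
        if n % x == 0 and (best_numel is None or n < best_numel):
            best_shape, best_numel = new_shape, n
    return best_shape


def f_pad_arg(shape, new_shape):
    """Hypothetical helper: build the `pad` tuple for torch.nn.functional.pad,
    which lists (left, right) pairs starting from the LAST dimension.
    All padding is placed on the right (end) of each dimension."""
    pad = []
    for s, ns in zip(reversed(shape), reversed(new_shape)):
        pad.extend((0, ns - s))
    return tuple(pad)
```

With a tensor t, the padding could then be applied as torch.nn.functional.pad(t, f_pad_arg(t.shape, min_pad_shape(t.shape, x))), which fills with zeros in its default constant mode. For larger x or ndim, the single-dimension closed form gives an upper bound that could be used to prune this search.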
https://stackoverflow.com/questions/66055262/torch-minimally-pad-tensor-such-that-num-elements-divisible-by-x February 05, 2021 at 07:13AM
