Tengo un tensor de Pytorch mask
de dimensiones,
torch.Size([8, 24, 24])
con valores únicos,
> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2093, 1054, 1461]))
Deseo reemplazar aleatoriamente el número de 2 a 0, de modo que los valores únicos y los recuentos en el tensor se conviertan en,
> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2500, 1054, 1054]))
He intentado usar torch.where
sin éxito. ¿Cómo se puede lograr esto?
Solución del problema
Una de las posibles soluciones es aplanando vía view
y numpy.random.choice
:
from numpy.random import choice
idx = torch.where(mask.view(-1) == 2)[0] # get all indicies of 2 in flat tensor
num_to_change = 2500 - 2093 # as follows from example abow
idx_to_change = choice(idx, size=num_to_change, replace=False)
mask.view(-1)[idx_to_change] = 0
No hay comentarios.:
Publicar un comentario