GOOGLE ADS

miércoles, 20 de abril de 2022

Tensor Pytorch: reemplaza aleatoriamente los valores que cumplen la condición

Tengo un tensor de Pytorch maskde dimensiones,

torch.Size([8, 24, 24])

con valores únicos,

> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2093, 1054, 1461]))

Deseo reemplazar aleatoriamente el número de 2 a 0, de modo que los valores únicos y los recuentos en el tensor se conviertan en,

> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2500, 1054, 1054]))

He intentado usar torch.wheresin éxito. ¿Cómo se puede lograr esto?


Solución del problema

Una de las posibles soluciones es aplanando vía viewy numpy.random.choice:

from numpy.random import choice
idx = torch.where(mask.view(-1) == 2)[0] # get all indicies of 2 in flat tensor
num_to_change = 2500 - 2093 # as follows from example abow
idx_to_change = choice(idx, size=num_to_change, replace=False)
mask.view(-1)[idx_to_change] = 0

No hay comentarios.:

Publicar un comentario

Flutter: error de rango al acceder a la respuesta JSON

Estoy accediendo a una respuesta JSON con la siguiente estructura. { "fullName": "FirstName LastName", "listings...