Is there a way to vectorize this code in pytorch?
```python
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b
```
With numpy, I can simply do:
```python
import numpy as np

@np.vectorize
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b
```
such that `myfunc(a=[1, 2, 3, 4], b=2)` returns `array([3, 4, 1, 2])`.
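Note that the NumPy documentation describes `np.vectorize` as a convenience rather than a performance tool: it is essentially a Python loop over the elements. A minimal sketch of a truly array-level NumPy version, using `np.where` (the direct NumPy counterpart of the PyTorch approach discussed below), might look like this. The name `myfunc_np` is hypothetical.

```python
import numpy as np

def myfunc_np(a, b):
    """Return a-b where a>b, otherwise a+b, computed element-wise on whole arrays."""
    a = np.asarray(a)  # accept plain lists such as [1, 2, 3, 4]
    # np.where picks from the first array where the condition holds, else from the second
    return np.where(a > b, a - b, a + b)

result = myfunc_np([1, 2, 3, 4], 2)
print(result)  # [3 4 1 2]
```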
Is there a way to do the same in pytorch?
It seems `torch.where()` is the way to go:
```python
import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor(2)
torch.where(a > b, a - b, a + b)
```
The code above returns `tensor([3, 4, 1, 2])`.
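To wrap this up as a reusable function, a minimal sketch (the name `myfunc_torch` is hypothetical) could look like the following. It assumes `a` is a tensor and `b` is either a Python number or a tensor broadcastable against `a`, and it checks the result against the original scalar `myfunc` applied element by element in a Python loop.

```python
import torch

def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b (scalar version from the question)"
    if a > b:
        return a - b
    else:
        return a + b

def myfunc_torch(a, b):
    """Vectorized version: b may be a Python number or a tensor broadcastable to a."""
    return torch.where(a > b, a - b, a + b)

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8]])
b = torch.tensor([2, 2, 7, 7])  # broadcast across both rows of a

out = myfunc_torch(a, b)
print(out)  # tensor([[ 3,  4, 10, 11], [ 3,  4, 14,  1]])

# Reference result: apply the scalar function element by element in Python
ref = torch.tensor([[myfunc(x, y) for x, y in zip(row.tolist(), b.tolist())]
                    for row in a])
print(torch.equal(out, ref))  # True
```

Keep in mind that `torch.where` evaluates both `a - b` and `a + b` for every element and then selects between them; for cheap arithmetic like this, that is usually still far faster than looping in Python.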