
Is there a way to vectorize this code in PyTorch?

def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

With NumPy I can simply do

import numpy as np

@np.vectorize
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

such that myfunc(a=[1, 2, 3, 4], b=2) returns array([3, 4, 1, 2]).
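Note that np.vectorize is a convenience wrapper. According to the NumPy documentation, its implementation is essentially a Python for loop, so it does not make the code faster. Below is a minimal sketch of a truly vectorized NumPy version that uses np.where instead. The helper name `myfunc_np` is hypothetical and does not come from the original post.

```python
import numpy as np

def myfunc_np(a, b):
    """Element-wise: return a-b where a>b, otherwise a+b."""
    a = np.asarray(a)
    b = np.asarray(b)
    # np.where takes values from the second argument where the condition
    # is True and from the third argument elsewhere; b broadcasts against a.
    return np.where(a > b, a - b, a + b)

print(myfunc_np(a=[1, 2, 3, 4], b=2))  # [3 4 1 2]
```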

Is there a way to do the same in PyTorch?

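torch.where(condition, input, other) does exactly this. It takes each element from input where condition is True and from other elsewhere, broadcasting b against a: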

import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor(2)
torch.where(a>b, a-b, a+b)

The code above returns tensor([3, 4, 1, 2]).
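The same idea can be wrapped in a reusable function that accepts Python lists, scalars or tensors, similar to how the np.vectorize version is called. This is a minimal sketch: the name `myfunc_torch` is hypothetical, and torch.as_tensor is used here only to convert the inputs.

```python
import torch

def myfunc_torch(a, b):
    """Element-wise: return a-b where a>b, otherwise a+b."""
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    # Both a - b and a + b are computed for every element, then
    # torch.where selects per element. b may be a scalar or any shape
    # that broadcasts against a.
    return torch.where(a > b, a - b, a + b)

print(myfunc_torch([1, 2, 3, 4], 2))                        # tensor([3, 4, 1, 2])
print(myfunc_torch([1, 2, 3, 4], torch.tensor([0, 5, 3, 1])))  # tensor([1, 7, 6, 3])
```

Unlike the Python if/else, torch.where is not lazy. Both branches are always evaluated for the whole tensor, and only the selection happens per element.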