I'm new to PyTorch, so I'm trying to figure out how to use `torch.compile` in code I haven't developed.

I have this for example:

    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weights[feature], map_location=device))

So I tried:

    net = torch.compile(UNet(n_channels=1, n_classes=1))
    net.load_state_dict(torch.load(weights[feature], map_location=device))

But execution fails with:

    RuntimeError: Error(s) in loading state_dict for OptimizedModule
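
(For context: `torch.compile` doesn't modify the model in place; it returns an `OptimizedModule` wrapper that keeps the original model as a submodule named `_orig_mod`, so the wrapper's state-dict keys carry a `_orig_mod.` prefix and no longer match a checkpoint saved from the plain model. A minimal sketch of the mismatch, as of PyTorch 2.x, using a stand-in `nn.Linear` instead of the real UNet:)

    import torch
    import torch.nn as nn

    plain = nn.Linear(4, 2)          # stand-in for UNet
    compiled = torch.compile(plain)  # OptimizedModule wrapper

    print(list(plain.state_dict()))     # ['weight', 'bias']
    print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

    # Loading a checkpoint saved from the plain model into the wrapper
    # therefore fails with missing/unexpected keys, as in the error above.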

Inside my class `UNet(nn.Module)` I have instances of several other `nn.Module` subclasses, so I thought I could use `torch.compile` there as well.

I've seen the new tutorials on the PyTorch website, but nothing specific to my case yet.

I'm not sure how I would do it. Here's the main part of the code using the `net` instance:

    from rcia_tools.helpers.for_torch import UNet, predict_bscans, torch
    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weights[feature], map_location=device))
    net.to(device)
    for group_id in base.groups:
        for idx in range(base.size[group_id][1]):
            mask_probabilities = predict_bscans(net, np.array(bscans), device, batch_num=n_size)
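
(A side note, independent of `torch.compile`: this loop is pure inference, so unless `predict_bscans` already handles it internally, which isn't visible here, it's usually worth switching the model to eval mode and disabling autograd; eval mode also matters for correctness given the BatchNorm layers in `DoubleConv` below. A sketch:)

    net.eval()  # use running BN statistics, disable dropout
    with torch.inference_mode():  # skip autograd bookkeeping during inference
        mask_probabilities = predict_bscans(net, np.array(bscans), device, batch_num=n_size)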

In `for_torch.py`:

    class UNet(nn.Module):
        def __init__(self, n_channels, n_classes, bilinear=True):
            super().__init__()
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            factor = 2 if bilinear else 1
            self.down4 = Down(512, 1024 // factor)
            self.up1 = Up(1024, 512, bilinear)
            self.up2 = Up(512, 256, bilinear)
            self.up3 = Up(256, 128, bilinear)
            self.up4 = Up(128, 64 * factor, bilinear)
            self.outc = OutConv(64, n_classes)

        def forward(self, x):
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
            return logits

    class DoubleConv(nn.Module):
        """(convolution => [BN] => ReLU) * 2"""

    class Down(nn.Module):
        """Downscaling with maxpool then double conv"""

    class Up(nn.Module):
        """Upscaling then double conv"""

    class OutConv(nn.Module):
        ...

    def predict_bscans(unet, bscans, device, batch_num=2):
        """Tiles and Segments Bscan"""
        ...
This is essentially all the PyTorch in my code. I'm just wondering how I could apply `torch.compile()` to check if I can benefit from it, or from any other particular optimisations.
So I tried compiling only after loading the state dict into the plain model:

    net = torch.compile(net)

And that worked. However, I can't really say whether it got faster. I don't have a proper benchmark and I'm using a busy platform (shared with other colleagues). That said, I did a few runs with and without compile, and it seems to save 10 to 30 s: ~120 s (compiled) vs. ~150 s (not compiled). It may save even more in the long run.
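
(For completeness, the order that works, same code as in the question, just compiled last; `weights[feature]` and `device` come from the surrounding script:)

    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load(weights[feature], map_location=device))
    net.to(device)
    net = torch.compile(net)  # wrap the already-loaded model

(If you do need to load into an already-compiled model, the wrapper keeps the original model as `net._orig_mod`, so `net._orig_mod.load_state_dict(...)` should also work, and saving `net._orig_mod.state_dict()` keeps checkpoints free of the `_orig_mod.` prefix. Behaviour as of PyTorch 2.x.)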

BTW, should `torch.compile()` make a difference for CPU-only cases? I tried it there but saw some confusing messages like:

    No CUDA runtime is found, using CUDA_HOME='/usr'

yet the code worked. I didn't properly benchmark it, but my first impression is that it didn't make any difference.
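
(That message appears to be a warning from the C++/CUDA extension machinery failing to find a CUDA toolkit, and is harmless on a CPU-only box; the default Inductor backend also generates C++ code for CPU. A quick standalone way to check that CPU compilation happens at all; `f` here is just a throwaway test function:)

    import torch

    @torch.compile
    def f(x):
        return torch.sin(x) + torch.cos(x)

    x = torch.randn(1000)
    f(x)  # first call is slow: it triggers compilation
    f(x)  # later calls run the compiled code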

There's some nuance to how you benchmark PT 2.0; you can take a look at mlsys-experiments/benchmark-pt2.py at main · msaroufim/mlsys-experiments · GitHub. I'm planning on upstreaming that script to core this week.
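
(The key point, as a minimal hand-rolled sketch rather than that script: the first call pays the compilation cost, so warm up before timing, and synchronize when on GPU; `bench` is a hypothetical helper:)

    import time
    import torch

    def bench(fn, x, iters=50):
        fn(x)  # warm-up: the first call triggers compilation
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return (time.perf_counter() - t0) / iters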

CPU should still see speedups, but the type of CPU might have an impact; the most dramatic speedups will be on GPUs with Tensor Cores.