A collaborator sent me a Python pickle, which contains (among other things) a Torch tensor. For example, say they ran this on their computer:

import torch
import pickle

exple = {'some_number': 1, 'some_tensor': torch.tensor([[1., 2., 3.]])}
# The with-block closes the file even if pickling fails
with open('my_tensor.pkl', 'wb') as file:
    pickle.dump(exple, file)

Now, on my computer, I can load the file like this:

# read_pickle.py
import pickle

def read_pickle(filepath):
    # Open in binary mode; the file is closed automatically when the block exits
    with open(filepath, 'rb') as file:
        return pickle.load(file)

R code:

reticulate::use_virtualenv("torch")
reticulate::source_python("read_pickle.py")
data <- read_pickle("my_tensor.pkl")
#> $some_number
#> [1] 1
#> $some_tensor
#> tensor([[1., 2., 3.]])

I can easily convert the non-Torch part to standard R objects:

as.data.frame(data[1])
#>   some_number
#> 1           1

but I can't seem to do the same for the tensor:

as.data.frame(data)
#> Error in as.data.frame.default(x[[i]], optional = TRUE) : 
#>   cannot coerce class ‘c("torch.Tensor", "torch._C.TensorBase", "python.builtin.object"’ to a data.frame
torch::as_array(data$some_tensor)
#> Error in UseMethod("as_array", x) : 
#>   no applicable method for 'as_array' applied to an object of class "c('torch.Tensor', 'torch._C.TensorBase', 'python.builtin.object')"

For that last one, it seems that a tensor created in R and a tensor imported from the pickle do not have the same class:

class(data$some_tensor)
#> [1] "torch.Tensor"          "torch._C.TensorBase"   "python.builtin.object"
t <- torch::torch_tensor(1:3)
class(t)
#> [1] "torch_tensor" "R7"  
torch::as_array(t)
#> [1] 1 2 3

So, is there an "easy" way to read that Torch tensor as an R object?

A not-very-satisfying solution is to modify the Python reading function so that it converts Torch tensors to NumPy arrays before returning them to R:

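# read_pickle.py
import pickle
import torch

def read_pickle(filepath):
    with open(filepath, 'rb') as file:
        content = pickle.load(file)
    # Replace top-level tensor values with NumPy arrays, which reticulate turns into R arrays.
    # detach() and cpu() avoid errors for tensors that require grad or live on a GPU.
    return {k: (v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v)
            for k, v in content.items()}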
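The dict comprehension above only converts tensors stored directly as values of a top-level dict; a tensor nested inside a list, tuple, or inner dict would still reach R as a Python object. A minimal sketch of a recursive variant follows, assuming the pickle contains only plain dicts, lists, and tuples around the tensors. The helper name convert_leaves is hypothetical, a namedtuple is rebuilt as a plain tuple, and objects of any other type (including custom classes holding tensors) are returned untouched. torch is imported inside read_pickle so the helper itself can be used without PyTorch installed.

```python
# read_pickle.py (hypothetical recursive variant)
import pickle


def convert_leaves(obj, is_target, convert):
    """Return a copy of obj where every value matching is_target is replaced by convert(value).

    Only plain dicts, lists and tuples are walked; any other object is returned as-is.
    """
    if is_target(obj):
        return convert(obj)
    if isinstance(obj, dict):
        return {k: convert_leaves(v, is_target, convert) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_leaves(v, is_target, convert) for v in obj]
    if isinstance(obj, tuple):
        # A namedtuple also matches here and comes back as a plain tuple.
        return tuple(convert_leaves(v, is_target, convert) for v in obj)
    return obj


def read_pickle(filepath):
    import torch  # imported here so convert_leaves is usable without PyTorch

    with open(filepath, "rb") as file:
        content = pickle.load(file)
    # Replace every torch.Tensor, at any depth, with a NumPy array for reticulate.
    return convert_leaves(
        content,
        is_target=lambda v: isinstance(v, torch.Tensor),
        convert=lambda t: t.detach().cpu().numpy(),
    )
```

From R it is used exactly as before: reticulate::source_python("read_pickle.py") followed by read_pickle("my_tensor.pkl"). As a torch-free check of the traversal, convert_leaves({'a': 1, 'b': [2, (3, 'x')]}, lambda v: isinstance(v, int), lambda v: v * 10) returns {'a': 10, 'b': [20, (30, 'x')]}.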
              

Torch tensors have a .numpy() method that converts them to NumPy arrays, and reticulate converts NumPy arrays to R arrays. There are a few ways to apply this: you can register an S3 method for py_to_r so the conversion happens globally, or you can convert the tensors manually after unpickling. For example:

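# Option 1: register a py_to_r method so every torch.Tensor is converted automatically
registerS3method("py_to_r", "torch.Tensor", function(x) x$numpy(), asNamespace("reticulate"))
x <- reticulate::py_load_object("my_tensor.pkl")

# Option 2 (no S3 method needed): convert the tensors manually after loading
# (the \(x) lambda syntax requires R >= 4.1)
x <- rapply(list(x), \(x) x$numpy(), classes = "torch.Tensor", how = "replace")[[1]]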

The R package torch is not the same thing as using torch through reticulate. The R torch package wraps the C++ torch library directly and provides its own R wrappers. Reticulate embeds a Python interpreter in an R session, so using torch through reticulate is the same as using it through the Python interface (i.e., PyTorch). That is why torch::as_array() has no method for a tensor that comes from Python.