添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
import numpy as np
def _prox_w21_norm(W,param):
    task_weight = np.linalg.norm(W, axis=1)
    zero_idx = task_weight != 0
    radial_mat = np.zeros(W.shape)
    radial_mat[zero_idx] = W[zero_idx]/task_weight[zero_idx,np.newaxis]
    updated_mat = W - param*radial_mat
    idx = W*updated_mat <= 0
    updated_mat[idx] = 0
    return updated_mat

この際最適化をForward backward splitting(ダッチさん2009)でといていると反復の途中で

<class 'numpy.float64'> # W の要素
<class 'float'> # W の要素
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-131-6d534db85f86> in <module>()
      1 A = martins()
----> 2 B = martins_smooth()
      3 pl.loglog(MARTINS_OBJ, label='MARTINS')
      4 pl.loglog(MARTINS_SMOOTH_OBJ,label='MARTINS SMOOTH')
      5 pl.legend()
<ipython-input-129-d92ef3b97804> in martins_smooth()
      5     MARTINS_SMOOTH_STEP.clear()
      6     for t in range(LOOP):
----> 7         obj = objective(W)
      8         MARTINS_SMOOTH_OBJ.append(obj)
      9         gamma, W = linesearch(W, gamma)
<ipython-input-126-e6b93142ddc4> in objective(W)
      1 def objective(W):
----> 2     return l2_loss(W)+W21_PARAM*w21_norm(W)+W12_PARAM*w21_norm(W.T)
<ipython-input-125-a9ca0e6faaa4> in w21_norm(W)
      1 def w21_norm(W):
----> 2     return np.sum(np.linalg.norm(W, axis=1))
/usr/local/lib/python3.5/dist-packages/numpy/linalg/linalg.py in norm(x, ord, axis, keepdims)
   2158             # special case for speedup
   2159             s = (x.conj() * x).real
-> 2160             return sqrt(add.reduce(s, axis=axis, keepdims=keepdims))
   2161         else:
   2162             try:
AttributeError: 'float' object has no attribute 'sqrt'

となり,エラーを吐かれてしまうという謎の現象に直面した.
反復回数が数千回目で発生するので実装に間違いはなさそう.
暗黙の型変換が起きてこのようなエラーが吐かれてしまうと考え,numpyの実装からnumpy.linalg.normの自分が使うところのみを修正,改善して対応した.

from numpy.core import add 
def norm(x, axis=None):
    tmp = add.reduce((x.conj() * x).real, axis=axis)
    if type(tmp) == float:
        return math.sqrt(tmp)
    else:
        return np.sqrt(tmp.astype(np.float64))
    
array.astype(np.float32)

とやればエラーに直面しない