import numpy as np
def _prox_w21_norm(W,param):
task_weight = np.linalg.norm(W, axis=1)
zero_idx = task_weight != 0
radial_mat = np.zeros(W.shape)
radial_mat[zero_idx] = W[zero_idx]/task_weight[zero_idx,np.newaxis]
updated_mat = W - param*radial_mat
idx = W*updated_mat <= 0
updated_mat[idx] = 0
return updated_mat
この際最適化をForward backward splitting(ダッチさん2009)でといていると反復の途中で
<class 'numpy.float64'>
<class 'float'>
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-131-6d534db85f86> in <module>()
1 A = martins()
----> 2 B = martins_smooth()
3 pl.loglog(MARTINS_OBJ, label='MARTINS')
4 pl.loglog(MARTINS_SMOOTH_OBJ,label='MARTINS SMOOTH')
5 pl.legend()
<ipython-input-129-d92ef3b97804> in martins_smooth()
5 MARTINS_SMOOTH_STEP.clear()
6 for t in range(LOOP):
----> 7 obj = objective(W)
8 MARTINS_SMOOTH_OBJ.append(obj)
9 gamma, W = linesearch(W, gamma)
<ipython-input-126-e6b93142ddc4> in objective(W)
1 def objective(W):
----> 2 return l2_loss(W)+W21_PARAM*w21_norm(W)+W12_PARAM*w21_norm(W.T)
<ipython-input-125-a9ca0e6faaa4> in w21_norm(W)
1 def w21_norm(W):
----> 2 return np.sum(np.linalg.norm(W, axis=1))
/usr/local/lib/python3.5/dist-packages/numpy/linalg/linalg.py in norm(x, ord, axis, keepdims)
2158
2159 s = (x.conj() * x).real
-> 2160 return sqrt(add.reduce(s, axis=axis, keepdims=keepdims))
2161 else:
2162 try:
AttributeError: 'float' object has no attribute 'sqrt'
となり,エラーを吐かれてしまうという謎の現象に直面した.
反復回数が数千回目で発生するので実装に間違いはなさそう.
暗黙の型変換が起きてこのようなエラーが吐かれてしまうと考え,numpyの実装からnumpy.linalg.normの自分が使うところのみを修正,改善して対応した.
from numpy.core import add
def norm(x, axis=None):
tmp = add.reduce((x.conj() * x).real, axis=axis)
if type(tmp) == float:
return math.sqrt(tmp)
else:
return np.sqrt(tmp.astype(np.float64))
array.astype(np.float32)
とやればエラーに直面しない