Source code for sfa.algorithms.sp

# -*- coding: utf-8 -*-

import sys
if sys.version_info <= (2, 8):
    from builtins import super

import numpy as np

from .np import NetworkPropagation
from .np import NetworkPropagationParameterSet


def create_algorithm(abbr):
    """Build the algorithm object provided by this module.

    Parameters
    ----------
    abbr
        Forwarded unchanged to the ``SignalPropagation`` constructor
        (presumably the abbreviation identifying the algorithm --
        confirm against callers).

    Returns
    -------
    SignalPropagation
        A newly constructed instance.
    """
    algorithm = SignalPropagation(abbr)
    return algorithm
# end of def
class SignalPropagationParameterSet(NetworkPropagationParameterSet):
    """Parameter set for the signal propagation algorithm.

    Adds nothing of its own: initialization is delegated entirely to
    ``NetworkPropagationParameterSet``.
    """

    def initialize(self):
        """Initialize the parameters by deferring to the parent class."""
        super().initialize()
    # end of def initialize
# end of class ParameterSet
class SignalPropagation(NetworkPropagation):
    """Signal propagation algorithm built on ``NetworkPropagation``.

    The state is updated with the rule::

        x(t+1) = a*W.dot(x(t)) + (1-a)*b

    and can be obtained either in closed form (``propagate_exact``) or
    by fixed-point iteration (``propagate_iterative``).
    """

    def __init__(self, abbr):
        """Initialize the parent class with ``abbr`` and set the name."""
        super().__init__(abbr)
        self._name = "Signal propagation algorithm"
    # end of def __init__

    def prepare_exact_solution(self):
        """Compute and cache the matrix ``M`` of the exact solution.

        The update rule is::

            x(t+1) = a*W.dot(x(t)) + (1-a)*b,

        where ``a`` is alpha. When ``t -> inf``, both ``x(t+1)`` and
        ``x(t)`` converge to the stationary state ``s``. Then::

            s = a*W*s + (1-a)*b
            (I - a*W)*s = (1-a)*b
            s = (I - a*W)^-1 * (1-a)*b
            s = M*b,    where M = (1-a)*(I - a*W)^-1.

        Reads ``self._W`` and ``self._params.alpha`` and stores the
        result in ``self._M``.
        """
        W = self._W
        a = self._params.alpha
        M0 = np.eye(W.shape[0]) - a*W
        self._M = (1-a)*np.linalg.inv(M0)
    # end of def prepare_exact_solution

    def prepare_iterative_solution(self):
        """Do nothing: the iterative solution needs no preparation."""
        pass  # Nothing...
    # end of def prepare_iterative_solution

    def propagate_exact(self, b):
        """Return the stationary state ``M.dot(b)`` for the input ``b``.

        ``M`` is recomputed first when ``self._weight_matrix_invalidated``
        is truthy. This method does not reset that flag itself
        (NOTE(review): presumably it is managed elsewhere -- confirm).
        """
        if self._weight_matrix_invalidated:
            self.prepare_exact_solution()

        return self._M.dot(b)
    # end of def propagate_exact

    def propagate_iterative(self,
                            W,
                            xi,
                            b,
                            a=0.5,
                            lim_iter=1000,
                            tol=1e-5,
                            get_trj=False):
        """Iterate ``x(t+1) = a*W.dot(x(t)) + (1-a)*b`` until convergence.

        Parameters
        ----------
        W
            Weight matrix; only its ``dot`` method is used.
        xi
            Initial state, converted to a ``float64`` array (copied).
        b
            Constant input term of the update rule.
        a
            Mixing coefficient alpha of the update rule.
        lim_iter
            Maximum number of iterations.
        tol
            The loop stops once the norm of the difference between two
            consecutive states is less than or equal to this value.
        get_trj
            When truthy, the trajectory is returned instead of the
            iteration count.

        Returns
        -------
        tuple
            ``(x, num_iter)`` when ``get_trj`` is falsy, otherwise
            ``(x, trajectory)`` where ``trajectory`` is an array holding
            the initial state followed by every iterate that did not
            satisfy the tolerance. The state that satisfies the
            tolerance is returned as ``x`` but is not appended to the
            trajectory. When ``lim_iter <= 0`` no update is performed
            and ``x`` is the initial state.
        """
        # np.array copies by default, so the caller's xi is never aliased.
        x_t1 = np.array(xi, dtype=np.float64)

        # Bind the result up front so that lim_iter <= 0 returns the
        # initial state instead of raising UnboundLocalError.
        x_t2 = x_t1

        # Record the initial state only when the trajectory is requested.
        trj_x = [x_t1.copy()] if get_trj else None

        # Main loop
        num_iter = 0
        for _ in range(lim_iter):
            # Main formula
            x_t2 = a*W.dot(x_t1) + (1-a)*b
            num_iter += 1

            # Check termination condition
            if np.linalg.norm(x_t2 - x_t1) <= tol:
                break

            # Add the current state to the trajectory
            if get_trj:
                trj_x.append(x_t2)

            # Update the state (x_t2 is rebound to a new array each pass,
            # so no defensive copy is needed here).
            x_t1 = x_t2
        # end of for

        # Use the same truthiness test that created trj_x above; the
        # former ``get_trj is False`` check sent falsy non-bool flags
        # (0, None, numpy.False_) to the trajectory branch, where trj_x
        # was never defined.
        if get_trj:
            return x_t2, np.array(trj_x)

        return x_t2, num_iter
    # end of def propagate_iterative
# end of class SignalPropagation