# Source code for netrd.reconstruction.mean_field

"""
exact_mean_field.py
---------------------
Reconstruction of graphs using the exact mean field
author: Brennan Klein
email: brennanjamesklein at gmail dot com
submitted as part of the 2019 NetSI Collabathon
"""
from .base import BaseReconstructor
import numpy as np
from scipy import linalg
from scipy.integrate import quad
from scipy.optimize import fsolve
from ..utilities import create_graph, threshold


class MeanField(BaseReconstructor):
    """Reconstruct a graph from time series with a mean field approximation."""

    def fit(
        self, TS, exact=True, stop_criterion=True, threshold_type='range', **kwargs
    ):
        """Infer inter-node coupling weights using a mean field approximation.

        From the paper: "Exact mean field (eMF) is another mean field
        approximation, similar to naive mean field and thouless anderson
        palmer. We can improve the performance of this method by adding our
        stopping criterion. In general, eMF outperforms nMF and TAP, but it
        is still worse than FEM and MLE, especially in the limit of small
        sample sizes and large coupling variability." For details see [1]_.

        The results dictionary also stores the weight matrix as
        `'weights_matrix'` and the thresholded version of the weight matrix
        as `'thresholded_matrix'`.

        Parameters
        ----------

        TS (np.ndarray)
            Array consisting of :math:`L` observations from :math:`N`
            sensors, with one row per sensor (shape ``(N, L)``).

        exact (bool)
            If True, use the exact mean field approximation. If False, use
            the naive mean field approximation.

        stop_criterion (bool)
            If True, prevent overly-long runtimes. Only applies for exact
            mean field.

        threshold_type (str)
            Which thresholding function to use on the matrix of
            weights. See `netrd.utilities.threshold.py` for
            documentation. Pass additional arguments to the thresholder
            using ``**kwargs``.

        Returns
        -------

        G (nx.Graph or nx.DiGraph)
            a reconstructed graph.

        Raises
        ------

        ValueError
            If `TS` is not two-dimensional.

        scipy.linalg.LinAlgError
            If the equal-time covariance matrix of `TS` is singular.

        References
        ----------

        .. [1] https://github.com/nihcompmed/network-inference/blob/master/sphinx/codesource/inference.py

        """
        N, _ = np.shape(TS)  # N nodes; unpacking also enforces a 2-D input
        m = np.mean(TS, axis=1)  # empirical mean of every node

        # equal time correlation
        ds = TS.T - m
        C = np.cov(ds, rowvar=False, bias=True)
        C_inv = linalg.inv(C)

        # one-step-delayed correlation
        s1 = TS[:, 1:]
        ds1 = s1.T - np.mean(s1, axis=1)
        D = cross_cov(ds1, ds[:-1])

        # Shared by both variants: the naive estimate rescales it by
        # 1 / (1 - m**2), the exact estimate by a self-consistent factor.
        B = np.dot(D, C_inv)

        if exact:
            gauss_norm = 1 / np.sqrt(2 * np.pi)

            def mean_tanh(x, H, delta):
                # Standard-normal density times tanh(H + x * sqrt(delta)).
                return (
                    gauss_norm
                    * np.exp(-(x**2) / 2)
                    * np.tanh(H + x * np.sqrt(delta))
                )

            def mean_sech2(x, H, delta):
                # Standard-normal density times 1 - tanh(H + x * sqrt(delta))**2.
                return (
                    gauss_norm
                    * np.exp(-(x**2) / 2)
                    * (1 - np.square(np.tanh(H + x * np.sqrt(delta))))
                )

            def mean_gap(H, delta, target):
                # Root function for fsolve: Gaussian-averaged tanh minus the
                # empirical mean of the node. The solver may pass its unknown
                # as an array, so unwrap it and integrate over plain scalars.
                field = np.ravel(H)[0]
                y, _ = quad(mean_tanh, -np.inf, np.inf, args=(field, delta))
                return y - target

            W = np.empty((N, N))
            nloop = 100
            past = TS[:, :-1].T  # states at time t, one row per time step
            one_minus_m2 = 1 - m**2

            for i0 in range(N):
                cost = np.zeros(nloop + 1)
                delta = 1.0
                # Everything below depends on the node only, not the iteration.
                future_i = s1.T[:, i0]  # state of node i0 at time t + 1
                B_row = B[i0, :]
                delta_scale = np.sum((B_row**2) * one_minus_m2)

                # NOTE(review): cost[0] is 0 and a non-NaN mean of squares is
                # always >= 0, so with stop_criterion=True the test below is
                # already true at iloop == 1 and only one pass is executed per
                # node. Kept as is to leave the numerical results unchanged;
                # confirm whether cost[0] was meant to start at infinity.
                for iloop in range(1, nloop):
                    # Index the solution explicitly: float() on a length-1
                    # array is deprecated since NumPy 1.25.
                    root = fsolve(mean_gap, 0.0, args=(delta, m[i0]))
                    H = float(np.ravel(root)[0])

                    a, _ = quad(mean_sech2, -np.inf, np.inf, args=(H, delta))

                    if a != 0:
                        delta = (1 / (a**2)) * delta_scale

                    W_temp = B_row / a
                    H_temp = np.dot(past, W_temp)
                    cost[iloop] = np.mean((future_i - np.tanh(H_temp)) ** 2)

                    if stop_criterion and cost[iloop] >= cost[iloop - 1]:
                        break

                W[i0, :] = W_temp

        else:
            # naive mean field: rescale every row of B by 1 / (1 - m_i**2)
            A_inv = np.diag(1 / (1 - m**2))
            W = np.dot(A_inv, B)

        # threshold the network
        W_thresh = threshold(W, threshold_type, **kwargs)

        # construct the network
        # (self.results is not created here; presumably the base class
        # provides it -- verify against BaseReconstructor)
        self.results['graph'] = create_graph(W_thresh)
        self.results['weights_matrix'] = W
        self.results['thresholded_matrix'] = W_thresh
        G = self.results['graph']

        return G
def cross_cov(a, b):
    """Return the cross-covariance of `a` and `b` along axis 0.

    Both inputs are centered by subtracting their mean over axis 0, then
    the product of the transposed centered `a` with the centered `b` is
    divided by the number of rows of `a`:

    a, b --> <(a - <a>)(b - <b>)>  (axis=0)
    """
    n_rows = a.shape[0]
    a_centered, b_centered = (x - np.mean(x, axis=0) for x in (a, b))

    return np.matmul(a_centered.T, b_centered) / n_rows