# Source code for netrd.reconstruction.maximum_likelihood_estimation

"""
maximum_likelihood_estimation.py
--------------------------------
Reconstruction of graphs using maximum likelihood estimation
author: Brennan Klein
email: brennanjamesklein at gmail dot com
submitted as part of the 2019 NeTSI Collabathon
"""
from .base import BaseReconstructor
import numpy as np
from ..utilities import create_graph, threshold

class MaximumLikelihoodEstimation(BaseReconstructor):
    """Uses maximum likelihood estimation."""

    def fit(self, TS, rate=1.0, stop_criterion=True, threshold_type='degree', **kwargs):
        """Infer inter-node coupling weights using maximum likelihood estimation methods.

        For each node ``i`` a weight vector ``w`` is updated iteratively as

            ``w += (rate / L) * TS[:, :-1] @ (TS[i, 1:] - tanh(TS[:, :-1].T @ w))``

        so that the activity of all nodes at time ``t`` is used to predict
        node ``i`` at time ``t + 1``. Row ``i`` of the resulting weight
        matrix holds the fitted ``w`` for node ``i``.

        The results dictionary also stores the weight matrix as
        `'weights_matrix'` and the thresholded version of the weight matrix
        as `'thresholded_matrix'`.

        Parameters
        ----------

        TS (np.ndarray)
            Array of shape (N, L) consisting of :math:`L` observations
            (columns) from :math:`N` sensors (rows). Anything accepted by
            ``np.asarray`` is converted to a float array first.

        rate (float)
            rate term in maximum likelihood; it is divided by :math:`L`
            before being used as the update step size.

        stop_criterion (bool)
            if True, prevent overly-long runtimes: a node's updates stop as
            soon as the mean squared residual fails to decrease. If False,
            every node runs the full 10000 update steps.

        threshold_type (str)
            Which thresholding function to use on the matrix of
            weights. See `netrd.utilities.threshold.py` for
            documentation. Pass additional arguments to the thresholder
            using `**kwargs`.

        Returns
        -------

        G (nx.Graph or nx.DiGraph)
            a reconstructed graph.

        Raises
        ------

        ValueError
            If `TS` is not two-dimensional or has fewer than two
            observations per node (no transitions to fit).

        References
        ----------

        .. [1] https://github.com/nihcompmed/network-inference/blob/master/sphinx/codesource/inference.py

        """
        TS = np.asarray(TS, dtype=float)
        if TS.ndim != 2:
            raise ValueError(
                "TS must be a 2-D array of shape (N, L); got ndim={}".format(TS.ndim)
            )
        N, L = TS.shape  # N nodes, length L
        if L < 2:
            raise ValueError(
                "TS must contain at least 2 observations per node; got L={}".format(L)
            )

        rate = rate / L  # step size is scaled by the series length
        nloop = 10000  # hard cap on update steps per node

        s1 = TS[:, :-1]  # all nodes at times 0 .. L-2 (predictors)

        W = np.zeros((N, N))
        for i0 in range(N):
            st1 = TS[i0, 1:]  # time series activity of single node, shifted one step

            w = np.zeros(N)
            # With w == 0 the prediction tanh(s1.T @ w) is 0, so the first
            # residual is simply the target series.
            residual = st1
            # Start at +inf so the first step can never trigger the stop
            # (the previous implementation compared against a 100.0 fill
            # value via cost[-1], stopping large-magnitude data after one step).
            prev_cost = np.inf

            for _ in range(nloop):
                w += rate * np.dot(s1, residual)
                # One tanh per step: this residual feeds both the cost check
                # and the next iteration's update.
                residual = st1 - np.tanh(np.dot(s1.T, w))

                if stop_criterion:
                    cost = np.mean(residual ** 2)
                    # stop once the mean squared residual no longer decreases
                    if cost >= prev_cost:
                        break
                    prev_cost = cost

            W[i0, :] = w

        # threshold the network
        W_thresh = threshold(W, threshold_type, **kwargs)

        # construct the network
        self.results['graph'] = create_graph(W_thresh)
        self.results['weights_matrix'] = W
        self.results['thresholded_matrix'] = W_thresh

        return self.results['graph']