A robust estimator of mutual information for deep learning interpretability

We develop the use of mutual information (MI), a well-established metric in information theory, to interpret the inner workings of deep learning models. To accurately estimate MI from a finite number of samples, we present GMM-MI (pronounced "Jimmie"), an algorithm based on Gaussian mixture models that can be applied to both discrete and continuous settings. GMM-MI is computationally efficient, robust to the choice of hyperparameters, and provides the uncertainty on the MI estimate due to the finite sample size. We extensively validate GMM-MI on toy data for which the ground truth MI is known, comparing its performance against established mutual information estimators. We then demonstrate the use of our MI estimator in the context of representation learning, working with synthetic data and physical datasets describing highly non-linear processes. We train deep learning models to encode high-dimensional data within a meaningful compressed (latent) representation, and use GMM-MI to quantify both the level of disentanglement between the latent variables, and their association with relevant physical quantities, thus unlocking the interpretability of the latent representation. We make GMM-MI publicly available.


I. INTRODUCTION
The flexibility and expressiveness of deep learning (DL) models are attractive features, which have led to their application to a variety of scientific problems (see e.g. Raghu and Schmidt [1] for a recent review). Despite this recent progress, deep neural networks remain opaque models, and their power as universal approximators [2][3][4] comes at the expense of interpretability [5]. Many techniques have been developed to gain insight into such black-box models [6][7][8][9][10][11][12][13]. These solutions vary in their computational efficiency and in the range of tasks to which they can be applied; however, there is no consensus as to which method provides the most trustworthy interpretations, and a general framework to interpret deep neural networks is still an avenue of active investigation (see e.g. Li et al. [14], Linardatos et al. [15] for recent reviews).
DL models are also widely used in representation learning, where a high-dimensional dataset is compressed to a smaller set of variables; this latent representation should contain all the relevant information for downstream tasks such as reconstruction, classification or regression [16,17]. Disentanglement of these compressed variables is also often imposed, in order to associate each latent to a physical quantity of domain interest [17][18][19][20][21][22][23][24]. However, how best to access the information captured by these latent vectors and connect it to the relevant factors remain open questions.
In this work, we focus on representation learning and link the latent variables to relevant physical quantities by estimating their mutual information (MI), a well-established information-theoretic measure of the relationship between two random variables. MI allows us to interpret what the DL model has learned about the domain-specific parameters relevant to the problem: by interrogating the model through MI, we aim to discover what information is used by the model in making predictions, thus achieving the interpretation of its inner workings. We also use MI to quantify the level of disentanglement of the latent variables.
MI has found applications in a variety of scientific fields, including astrophysics [25][26][27][28][29][30][31][32][33], biophysics [34][35][36][37][38][39][40], and dynamical complex systems [41][42][43][44][45][46][47][48], to name a few. However, estimating the mutual information I(X, Y) between two random variables X and Y, given samples from their joint distribution p_{(X,Y)}, remains a long-standing challenge, since it requires p_{(X,Y)} to be known or estimated accurately [49,50]. When X and Y are continuous variables with values over $\mathcal{X} \times \mathcal{Y}$, I(X, Y) is defined as:

$$I(X, Y) = \int_{\mathcal{X} \times \mathcal{Y}} p_{(X,Y)}(x, y) \ln \frac{p_{(X,Y)}(x, y)}{p_X(x)\, p_Y(y)} \, \mathrm{d}x \, \mathrm{d}y, \quad (1)$$

where p_X and p_Y are the marginal distributions of X and Y, respectively, and ln refers to the natural logarithm, so that MI is measured in natural units (nat). I(X, Y) represents the amount of information one gains about Y by observing X (or vice versa): it captures the full dependence between the two variables, going beyond the Pearson correlation coefficient, since I(X, Y) = 0 if and only if X and Y are statistically independent [51]. A comprehensive summary of MI and its properties can be found in Vergara and Estévez [50].
The most straightforward estimator of I(X, Y) given samples of p_{(X,Y)} consists of binning the data and approximating Eq. (1) with a finite sum over the bins. This approach is heavily dependent on the binning scheme, and is prone to systematic errors [39,[52][53][54][55][56][57][58][59]. Kraskov et al. [56] proposed an estimator (hereafter referred to as KSG), based on k-nearest neighbors, which rewrites I(X, Y) in terms of the Shannon entropy, and then applies the Kozachenko-Leonenko entropy estimator [60] to each entropy term.

II. METHOD

A. GMM-MI

Our algorithm uses a GMM with c components to obtain a fit of the joint distribution p_{(X,Y)}:

$$p(x, y \mid \theta) = \sum_{i=1}^{c} w_i\, \mathcal{N}\!\left(x, y \mid \mu_i, \Sigma_i\right), \quad (2)$$

where θ is the set of weights w_{1:c}, means μ_{1:c} and covariance matrices Σ_{1:c}. With this choice, the marginals p(x) and p(y) are also GMMs, with parameters determined by θ. Our procedure for estimating MI and its associated uncertainty is as follows.
1. For a given number of GMM components c, we randomly initialize n_init different GMM models. Each set of initial GMM parameters is obtained by first randomly assigning the responsibilities, namely the probabilities that each point belongs to component i, sampling from a uniform distribution. The starting values of each μ_i and Σ_i are calculated as the sample mean and covariance matrix of all points, weighted by the responsibilities, while each w_i is initialized as the average responsibility across all points. Having multiple initializations is crucial to reduce the risk of stopping at local optima during the optimization procedure [87][88][89][90][91].
2. We fit the data using k-fold cross-validation: this means that we train a GMM on k − 1 subsets of the data (or "folds"), and evaluate the trained model on the remaining validation fold. Each fit is performed with the expectation-maximization algorithm [92], and terminates when the change in log-likelihood on the training data is smaller than a chosen threshold. We also add a small regularization constant ω to the diagonal of each covariance matrix, as described e.g. in Melchior and Goulding [91], to avoid singular covariance matrices.
3. We select the model with the highest mean validation log-likelihood across folds, $\hat{\ell}_c$, since it has the best generalization performance. Among the k models corresponding to $\hat{\ell}_c$, we also store the final GMM parameters with the highest validation log-likelihood on a single fold: these will be used to initialize each bootstrap fit in step 5, thus reducing the risk of stopping at local optima and significantly accelerating convergence.
4. We repeat steps 1-3, iteratively increasing the number of GMM components from c = 1. We stop when $\hat{\ell}_c - \hat{\ell}_{c-1}$ is smaller than a user-specified positive threshold, and select the value of c − 1 as the optimal number of GMM components to fit. In this way, we avoid overfitting the training data and adding too many components, which would considerably slow down the procedure while not significantly improving the density estimation.
5. We bootstrap the data n_b times, and fit a GMM to each bootstrapped realization. Each fit is initialized with the set of parameters selected in step 3, and uses the number of components selected in step 4; the MI of each fitted GMM is then estimated through Monte Carlo integration with M samples, yielding the full distribution of MI across bootstrap realizations, from which we report the mean and standard deviation.

A flowchart summarizing the GMM-MI procedure is shown in Fig. 1; a minimal code sketch of the pipeline is also given below. We choose the initialization procedure described in step 1 for its speed, but our implementation of GMM-MI also offers other initialization procedures that could alternatively be used. For instance, it is possible that the random initialization we set as default returns overlapping components which inhibit the optimization procedure; in those cases, we recommend switching to an initialization based on k-means [93]. On the other hand, k-means itself is only guaranteed to converge to local optima [94]; for this reason, we also provide the possibility to perturb the means by a user-specified scale after an initial call to k-means. We call this approach "randomized k-means", and offer full flexibility to select the most appropriate initialization type based on the data being analyzed.
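For illustration, the following minimal sketch reproduces the main steps of this pipeline using scikit-learn's GaussianMixture. It is a simplified stand-in rather than the public GMM-MI implementation: the random-responsibility initializations of step 1 and the warm starts of step 5 are omitted, and all function names and default values here are ours.

```python
# Minimal sketch of the GMM-MI pipeline (illustrative only; the public
# GMM-MI package differs in initialization options, warm starts and defaults).
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import KFold


def marginal_logpdf(gmm, z, dim):
    """ln p(z) for one marginal of a fitted 2D GMM (itself a 1D GMM)."""
    means = gmm.means_[:, dim]
    stds = np.sqrt(gmm.covariances_[:, dim, dim])
    log_comp = norm.logpdf(z[:, None], means[None, :], stds[None, :])
    return logsumexp(log_comp + np.log(gmm.weights_)[None, :], axis=1)


def mi_from_gmm(gmm, n_mc=100_000):
    """Monte Carlo estimate of Eq. (1) under the fitted GMM density."""
    xy, _ = gmm.sample(n_mc)
    log_joint = gmm.score_samples(xy)
    log_marg = marginal_logpdf(gmm, xy[:, 0], 0) + marginal_logpdf(gmm, xy[:, 1], 1)
    return np.mean(log_joint - log_marg)


def estimate_mi(samples, max_components=10, k=3, n_boot=100, threshold=1e-5):
    """Select c by k-fold cross-validation (steps 2-4), then bootstrap (step 5)."""
    best_c, prev_ll = 1, -np.inf
    for c in range(1, max_components + 1):
        folds = KFold(k, shuffle=True, random_state=0).split(samples)
        val_ll = [GaussianMixture(c, reg_covar=1e-15, random_state=0)
                  .fit(samples[tr]).score(samples[va]) for tr, va in folds]
        if np.mean(val_ll) - prev_ll < threshold:
            break  # an extra component does not improve the validation likelihood
        prev_ll, best_c = np.mean(val_ll), c
    rng = np.random.default_rng(0)
    mis = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(samples), len(samples))
        gmm = GaussianMixture(best_c, reg_covar=1e-15).fit(samples[idx])
        mis.append(mi_from_gmm(gmm))
    return np.mean(mis), np.std(mis)  # MI estimate and its uncertainty (nat)
```

In the actual implementation, each bootstrap fit is additionally warm-started from the parameters stored in step 3, which considerably accelerates convergence.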
Our implementation also allows the user to set a higher patience, i.e. to consider more than one additional component in step 4 after the mean validation log-likelihood has stopped improving; alternatively, it is possible to select the number of components yielding the lowest Akaike information criterion (AIC, [95]) or Bayesian information criterion (BIC, [96]), with details in Appendix A. All three methods implemented are computationally efficient, and aim to prevent the model from overfitting the available samples.

In many instances, the factors of variation that are used to generate the data are discrete variables [97]; in these cases, we need to estimate MI between a continuous variable X and a categorical variable F which can take v different values f_{1:v}. Assuming the v values have equal probability (as will be the case when considering the 3D Shapes dataset in Sect. IV A), the mutual information I(X, F) can be expressed as:

$$I(X, F) = \frac{1}{v} \sum_{i=1}^{v} \int_{\mathcal{X}} p_{(X|F)}(x|f_i) \ln \frac{p_{(X|F)}(x|f_i)}{p_X(x)} \, \mathrm{d}x, \quad (3)$$

with $p_X(x) = \frac{1}{v} \sum_{i=1}^{v} p_{(X|F)}(x|f_i)$, where we use a GMM to fit each conditional probability p_{(X|F)}(x|f_i). The full derivation of Eq. (3) can be found in Appendix B.
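As a sketch of how Eq. (3) can be evaluated in practice (again illustrative rather than the package API: one GMM per category, equiprobable categories, Monte Carlo integration; the function name and defaults are ours):

```python
import numpy as np
from scipy.special import logsumexp
from sklearn.mixture import GaussianMixture

def discrete_continuous_mi(x_by_class, n_components=3, n_mc=50_000, seed=0):
    """Estimate I(X, F) of Eq. (3) for equiprobable classes via per-class GMMs."""
    v = len(x_by_class)
    gmms = [GaussianMixture(n_components, random_state=seed).fit(x.reshape(-1, 1))
            for x in x_by_class]
    mi = 0.0
    for gmm in gmms:
        x_mc, _ = gmm.sample(n_mc)             # samples from p(x | f_i)
        log_cond = gmm.score_samples(x_mc)     # ln p(x | f_i)
        # ln p(x) = ln [(1/v) * sum_j p(x | f_j)]
        log_marg = logsumexp(
            np.stack([g.score_samples(x_mc) for g in gmms]), axis=0) - np.log(v)
        mi += np.mean(log_cond - log_marg) / v
    return mi
```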

B. Alternative estimators
In order to validate our algorithm, we compare it with two established estimators of MI. The KSG estimator, first proposed in Kraskov et al. [56], rewrites MI as:

$$I(X, Y) = H(X) + H(Y) - H(X, Y), \quad (4)$$

where H(·) refers to the Shannon entropy, defined for a single variable as:

$$H(X) = -\int_{\mathcal{X}} p_X(x) \ln p_X(x) \, \mathrm{d}x. \quad (5)$$

The Kozachenko-Leonenko estimator [60] is then used to evaluate the entropy in Eq. (5):

$$\hat{H}(X) = -\psi(k) + \psi(N) + \ln c_d + \frac{d}{N} \sum_{i=1}^{N} \ln \epsilon_i(k), \quad (6)$$

where ψ(·) is the digamma function, k is the chosen number of nearest neighbors, N is the number of available samples, d is the dimensionality of X, c_d is the volume of the unit ball in d dimensions, and ε_i(k) is twice the distance between the i-th data point and its k-th neighbor. Applying Eq. (6) to each term in Eq. (4) would lead to biased estimates of MI [39,56]; for this reason, the KSG estimator instead considers a ball containing the k nearest neighbors around each sample, and counts the number of points within it along both the x and y directions. The resulting estimator of MI then becomes [39,56]:

$$\hat{I}(X, Y) = \psi(k) + \psi(N) - \left\langle \psi\!\left(n_x^{(k)} + 1\right) + \psi\!\left(n_y^{(k)} + 1\right) \right\rangle, \quad (7)$$

where n_x^{(k)} (n_y^{(k)}) represents the number of points within the ball in the x (y) direction, and ⟨·⟩ indicates the mean over the available samples. In our experiments, we consider the implementation of the KSG estimator available in SKLEARN.
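As a usage reference, a minimal example with scikit-learn's mutual_info_regression, which implements a KSG-style nearest-neighbor estimator (the sample size and seed are arbitrary):

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.default_rng(42)
# Correlated bivariate Gaussian samples with rho = 0.6.
cov = [[1.0, 0.6], [0.6, 1.0]]
xy = rng.multivariate_normal([0.0, 0.0], cov, size=2000)

# KSG-style estimate with k = 3 nearest neighbors (result in nat).
mi = mutual_info_regression(xy[:, [0]], xy[:, 1], n_neighbors=3)[0]
print(f"estimated MI: {mi:.3f} nat; "
      f"ground truth: {-0.5 * np.log(1 - 0.6**2):.3f} nat")
```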
We also compare our algorithm against the MINE estimator proposed in Belghazi et al. [69].
MI as defined in Eq. (1) can be interpreted as the KL divergence D_KL between the joint distribution and the product of the marginals:

$$I(X, Y) = D_{\rm KL}\left(p_{(X,Y)} \,\|\, p_X \otimes p_Y\right), \quad (8)$$

where the KL divergence between two generic probability distributions p_X and q_X defined over $\mathcal{X}$ is defined as:

$$D_{\rm KL}(p_X \,\|\, q_X) = \int_{\mathcal{X}} p_X(x) \ln \frac{p_X(x)}{q_X(x)} \, \mathrm{d}x. \quad (9)$$

The MINE estimator then considers the Donsker-Varadhan representation [78] of the KL divergence:

$$D_{\rm KL}(p_X \,\|\, q_X) = \sup_{T} \left\{ \mathbb{E}_{p_X}[T] - \ln \mathbb{E}_{q_X}\!\left[e^{T}\right] \right\}, \quad (10)$$

where the supremum is taken over all the functions T such that the expectations E[·] are finite, and T is parameterized with a neural network. In our experiments, we consider a publicly available implementation which includes the mitigation of the gradient bias through the use of an exponential moving average, as suggested in Belghazi et al. [69].
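To make the Donsker-Varadhan objective concrete, a minimal MINE-style sketch in PyTorch follows; it omits the exponential-moving-average gradient correction mentioned above, and all architectural choices and hyperparameters are illustrative:

```python
# Minimal MINE-style sketch of the Donsker-Varadhan bound, Eq. (10).
import math
import torch
import torch.nn as nn

class StatisticsNetwork(nn.Module):
    """Small network T(x, y) whose supremum approximates the KL divergence."""
    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, y], dim=1)).squeeze(1)

def dv_lower_bound(t_net, x, y):
    """E_p[T] - ln E_q[e^T], approximating the product of marginals
    by shuffling y within the batch."""
    t_joint = t_net(x, y).mean()
    y_shuf = y[torch.randperm(y.shape[0])]
    t_marg = torch.logsumexp(t_net(x, y_shuf), dim=0) - math.log(y.shape[0])
    return t_joint - t_marg

# Toy run on correlated Gaussian samples (rho = 0.6, true MI ~0.223 nat).
rho, n = 0.6, 2000
x = torch.randn(n, 1)
y = rho * x + math.sqrt(1 - rho**2) * torch.randn(n, 1)
t_net = StatisticsNetwork()
opt = torch.optim.Adam(t_net.parameters(), lr=1e-3)
for _ in range(2000):
    opt.zero_grad()
    loss = -dv_lower_bound(t_net, x, y)  # maximize the lower bound
    loss.backward()
    opt.step()
print(f"MINE estimate: {-loss.item():.3f} nat")
```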

C. Representation learning
We apply our MI estimator GMM-MI to interpret the latent space of representation-learning models. Specifically, we consider β-variational autoencoders (β-VAEs, [21,98]), where one neural network is trained to encode high-dimensional data D into a distribution over disentangled latent variables z, and a second network decodes samples of the latent distribution back into data points D̂. The two networks are trained together to minimize the following loss function:

$$\mathcal{L} = {\rm MSE}\left(\mathcal{D}, \hat{\mathcal{D}}\right) + \beta\, D_{\rm KL}\left(p_\phi(z|\mathcal{D}) \,\|\, p(z)\right), \quad (11)$$

where MSE indicates the mean squared error between the input data D and the reconstruction D̂, p_φ(z|D) represents the encoder parameterized by a set of weights φ, p(z) is the prior over the latent variables z, and β is a regularization constant which controls the level of disentanglement of z.
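For concreteness, a minimal sketch of this loss for a diagonal-Gaussian encoder and a standard normal prior (the closed-form KL below assumes this choice; the reduction and the value of β are illustrative):

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """Eq. (11) for a diagonal-Gaussian encoder and a N(0, I) prior.

    mu, logvar: encoder outputs of shape (batch, latent_dim); the KL term
    is the closed-form divergence KL(N(mu, sigma^2) || N(0, 1)) per sample.
    """
    mse = F.mse_loss(x_recon, x, reduction="mean")
    kl = -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return mse + beta * kl
```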
We will also reproduce the results of Lucie-Smith et al. [32] in Sect. IV B, for which the architecture is slightly different: the latent samples are combined with a given query (the radius r) and fed through the decoder to predict dark matter halo density profiles at each given r. This model is referred to as the interpretable variational encoder (IVE), with an analogous loss function to Eq. (11).

III. VALIDATION
In this section, we validate GMM-MI on toy data for which the MI can be computed analytically: we show that GMM-MI is in good agreement with the ground truth, as well as with other MI estimators, while returning the full distribution of MI, including its uncertainty. We run all the MI estimations on a single CPU node with 40 2.40 GHz Intel Xeon Gold 6148 cores, using no more than 300 MB of RAM, and report the speed performance in each case.

We first consider a bivariate Gaussian distribution with unit variance of each marginal and varying level of correlation ρ ∈ [−1, 1], following Belghazi et al. [69]. In this case, the true value of I(X, Y) can be obtained analytically by solving the integral in Eq. (1), yielding:

$$I(X, Y) = -\frac{1}{2} \ln\left(1 - \rho^2\right). \quad (12)$$

We consider two additional bivariate distributions: the gamma-exponential distribution [54,56,99,100], with density (α > 0 is a free parameter):

$$p_{(X,Y)}(x, y) = \frac{1}{\Gamma(\alpha)}\, x^{\alpha} e^{-x(1+y)}, \quad x, y > 0, \quad (13)$$

where Γ is the gamma function, and the ordered Weinman exponential distribution [54,56,99,100], with density:

$$p_{(X,Y)}(x, y) = \frac{2}{\alpha} \exp\left(-2x - \frac{y - x}{\alpha}\right), \quad 0 \leq x \leq y. \quad (14)$$

The true value of I(X, Y) for these distributions can be obtained analytically, and is reported in Appendix C. Since I(X, Y) is invariant under invertible transformations of each random variable [56], we consider ln(X) and ln(Y) when estimating MI in the case of the last two distributions. To demonstrate the power of our estimator, we restrict ourselves to the case where only a limited number of samples is available.

The results are reported in Fig. 2. The KSG estimator is the fastest and yields MI values closely matching the ground truth, but returns biased estimates in some regimes, e.g. around |ρ| = 0.4 in the bivariate Gaussian case and at small α in the ordered Weinman case. The MINE estimator is more computationally expensive and shows a relatively high variance, which is expected since MINE has been shown to be prone to variance overestimation due to the use of batches [72]. GMM-MI, on the other hand, is in good agreement with the ground truth in all cases, while also returning the uncertainty on each MI estimate.

We further validate GMM-MI by testing that it is unbiased, and that the estimated MI variance scales as N^{−1} as the number of available samples N increases. We additionally show that GMM-MI satisfies the MI property of invariance under invertible non-linear transformations [56].
We consider a bivariate Gaussian distribution with ρ = 0.6, and three different functions applied to one marginal variable Y: f(y) = y (identity), f(y) = y + 0.5y³ (cubic) and f(y) = ln(y + 5.5) (logarithmic). To deal with these datasets, we change the GMM-MI hyperparameters to k = 3, n_init = 5, and M = 10^5; however, we find no significant variations in the results even with different sets of hyperparameters. We repeat the MI estimation procedure 500 times, drawing N samples with a different seed every time, and considering N = 200, N = 2 000 and N = 20 000. For each estimate, we calculate the bias, i.e. the difference between the estimated value of MI and the ground truth.
We report violin plots of the bias and of the MI standard deviation returned by GMM-MI across the 500 trials in Fig. 3. The mean bias, indicated as a black cross, converges to 0 as N grows, and is always well below the typical value of the standard deviation, thus demonstrating that GMM-MI is unbiased. This holds even when considering the cubic and the logarithmic transformations, further confirming that GMM-MI correctly captures the invariance property of MI. Moreover, in all cases the standard deviations returned by GMM-MI follow a power law scaling as N^{−1/2}, consistent with the expected N^{−1} scaling of the MI variance.
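The invariance property can be checked numerically in a few lines; the sketch below uses the KSG estimator from scikit-learn as a stand-in for GMM-MI, together with the ground truth of Eq. (12), and the same three transformations:

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rho = 0.6
true_mi = -0.5 * np.log(1 - rho**2)  # Eq. (12): ~0.223 nat

rng = np.random.default_rng(7)
cov = [[1.0, rho], [rho, 1.0]]
xy = rng.multivariate_normal([0.0, 0.0], cov, size=20_000)
x, y = xy[:, [0]], xy[:, 1]

# MI should be unchanged under invertible transformations of Y.
for name, f in [("identity", lambda t: t),
                ("cubic", lambda t: t + 0.5 * t**3),
                ("logarithmic", lambda t: np.log(t + 5.5))]:
    mi = mutual_info_regression(x, f(y), n_neighbors=3)[0]
    print(f"{name:>12}: {mi:.3f} nat (truth {true_mi:.3f})")
```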

A. A note on bootstrap
As reported in Holmes and Nemenman [39], using bootstrap to associate an error bar to MI estimates can lead to catastrophic failures: duplicate points can be interpreted as fine-scale features, introducing spurious extra MI. In this section, we address this concern and empirically show that, despite including a bootstrap step, our procedure does not lead to biased estimates of MI.
We consider the same experiment described in Holmes and Nemenman [39], where a single dataset of N = 200 bivariate Gaussian samples with ρ = 0.6 is bootstrapped 20 times. We apply the KSG (with 3 neighbors, following Holmes and Nemenman [39]) and MINE estimators to each bootstrapped realization, and compare them against our estimator with n_b = 20. The results are reported in Fig. 4. The KSG estimator returns a mean MI biased by a factor of 4, while both MINE and our procedure return an accurate estimate. However, MINE is two orders of magnitude more computationally demanding, and returns a larger error bar than our procedure, since it tends to overestimate the variance, as discussed in Sect. III.
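The failure mode discussed above is easy to reproduce; the following sketch (illustrative, again using scikit-learn's KSG-style estimator) shows how duplicate points introduced by bootstrapping inflate a nearest-neighbor MI estimate:

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.default_rng(1)
cov = [[1.0, 0.6], [0.6, 1.0]]
xy = rng.multivariate_normal([0.0, 0.0], cov, size=200)
true_mi = -0.5 * np.log(1 - 0.6**2)

# KSG on bootstrapped data: duplicate points shrink nearest-neighbor
# distances, mimicking fine-scale structure and inflating MI.
boot_mis = []
for _ in range(20):
    idx = rng.integers(0, 200, 200)
    boot_mis.append(
        mutual_info_regression(xy[idx][:, [0]], xy[idx][:, 1],
                               n_neighbors=3)[0])
print(f"true MI: {true_mi:.3f} nat, "
      f"KSG on bootstraps: {np.mean(boot_mis):.3f} nat")
```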

IV. RESULTS
In this section, we apply our estimator to interpret the latent space of representation-learning models trained on three different datasets, ranging from synthetic images to cosmological simulations. We use our MI estimator to quantify the level of disentanglement of latent variables, and to link them to relevant physical parameters. In the following experiments, we consider k = 3 folds, n_init = 5 different initializations, a log-likelihood threshold on each individual fit of 10^{−5}, n_b = 100 bootstrap realizations, M = 10^5 MC samples, and a regularization scale of ω = 10^{−15}; as in the experiments described in Sect. III, we found GMM-MI to be robust to the hyperparameter choices. Obtaining the full distribution of MI with our algorithm typically takes O(10) s on the datasets we analyze, using the same hardware described in Sect. III.

A. 3D Shapes
We consider the 3D Shapes dataset [101,102], which consists of images of various shapes generated from six factors of variation: shape (4 values), scale (8 values), orientation (15 values), and floor, wall and object hue (10 values each). We train a β-VAE, as described in Sect. II C, on this dataset, using a 6-dimensional latent space and setting the value of β using cross-validation. After training, we encode 10% of the data, which were not used for training or validation, and use GMM-MI to estimate the MI between each latent variable and each generative factor, as well as between pairs of latent variables. Each latent is dependent upon only a single factor, except for z_2 and z_4, which appear entangled with scale and shape, as also found in Kim and Mnih [101]; for these two latents we measure I(z_2, z_4) = 0.04 ± 0.01 nat.

B. Dark matter halo density profiles
In the standard model of cosmology, only 5% of our Universe consists of baryonic matter, while the remainder consists of dark matter (25%) and dark energy (70%) [103]. Dark matter interacts only via the gravitational force, and gathers into stable large-scale structures, called 'halos', where galaxy formation typically occurs. Given the highly non-linear physical processes taking place during the formation of such structures, dark matter halos are commonly analyzed with cosmological N-body simulations, in which particles representing the dark matter are evolved in a box under the influence of gravity [104][105][106].
Dark matter halos forming within such simulations exhibit a universal spherically-averaged density profile as a function of radius [107][108][109]; this universality encompasses a huge range of halo masses and persists across different cosmological models. While the universality of the density profile is still not fully understood, Lucie-Smith et al. [32, LS22 hereafter] showed that it is possible to train a deep representation-learning model to compress raw dark matter halo data into a compact disentangled representation that contains all the information needed to predict dark matter density profiles. Following LS22, we consider 4332 dark matter halos from a single N-body simulation, and encode them using their IVE infall model with 3 latent variables. The latent representation is used to predict the dark matter halo density profile in 13 different radial bins.
We calculate the MI between the ground-truth halo density in each radial bin and each latent variable, aiming to reproduce the middle panel of fig. 4 in LS22, where further details can be found. We show the trend of MI for all radial bins and latent variables in Fig. 6. We compare the estimates from GMM-MI with those obtained using kernel density estimation (KDE) with different bandwidths, as done in LS22. A major difference between the two approaches is that our bands reflect the statistical uncertainty on MI due to the finite number of samples, whereas the KDE estimates vary significantly with the choice of bandwidth, as visible in Fig. 6; our algorithm, on the other hand, provides a robust way to select the hyperparameters, thus avoiding underfitting or overfitting the samples.

C. Stellar spectra
We consider the model presented in Sedaghat et al. [80, S21 hereafter], where a β-VAE is trained on about 7 000 unique real spectra with a 128-dimensional latent space. These spectra were collected by the High-Accuracy Radial-velocity Planet Searcher instrument (HARPS, [110,111]) in the spectral range 378-691 nm, and include mainly stellar spectra, even though Jupiter and asteroid contaminants are present in the dataset. All details about the data, the preprocessing steps and the training procedure can be found in S21.
To select the most informative latent variables, the median absolute deviation (MAD) is calculated for each of them; the rest of the analysis is carried out on the six most informative latents only. We calculate MI between each of these six variables and six known physical factors, all treated as continuous variables: the radial velocity of the star, its effective temperature T_eff, its mass, its metallicity [M/H], the atmospheric air mass, and the signal-to-noise ratio (SNR).

V. CONCLUSIONS
We presented GMM-MI (pronounced "Jimmie"), an efficient and robust algorithm to estimate the mutual information (MI) between two random variables given samples of their joint distribution. Our algorithm uses Gaussian mixture models (GMMs) to fit the available samples, and returns the full distribution of MI through bootstrapping, thus including the uncertainty on MI due to the finite sample size. GMM-MI is demonstrably accurate, and benefits from the flexibility and computational efficiency of GMMs. Moreover, it can be applied to both discrete and continuous settings, and is robust to the choice of hyperparameters.
We extensively validated GMM-MI on toy datasets for which the ground truth MI is known, showing equal or better performance with respect to established estimators like KSG [56] and MINE [69]; we also tested that GMM-MI respects MI invariance under invertible transformations, is unbiased and returns MI errors that scale as expected with sample size. We demonstrated the application of our estimator to interpret the latent space of three different deep representation-learning models trained on synthetic shape images, large-scale structure in cosmological simulations and real spectra of stars. We calculated both the MI between latent variables and physical factors, and the MI between the latent variables themselves, to investigate their degree of disentanglement, reproducing MI estimates obtained with various techniques, including histograms and kernel density estimators. These results further validate the accuracy of GMM-MI and confirm the power of MI for gaining interpretability of deep learning models.
We plan to extend our work by improving the density estimation with more flexible tools such as normalizing flows (NFs, [112,113]), which can be seamlessly integrated into neural network-based settings and can benefit from graphics processing unit (GPU) acceleration. Moreover, combining NFs with a differentiable numerical integrator would make our estimator amenable to backpropagation, thus allowing its use in the context of MI optimization. We will explore this avenue in future work.

DATA AVAILABILITY STATEMENT
GMM-MI is publicly available in this GitHub repository: https://github.com/dpiras/GMM-MI, together with all data and results from the paper.

ACKNOWLEDGMENTS
We thank Nima Sedaghat, Martino Romaniello and Vojtech Cvrcek for sharing the stellar spectra model and data. We are also grateful to Justin Alsing for useful discussions about initialization procedures for GMM fitting. DP was supported by the UCL Provost's Strategic Development Fund.

Appendix B: Derivation of Eq. (3)

To derive Eq. (3), we write the joint distribution of the continuous variable X and the categorical variable F, which takes the values f_{1:v}, as:

$$p_{(X,F)}(x, f) = \sum_{i=1}^{v} p_{(X|F)}(x|f_i)\, P(F = f_i)\, \delta(f - f_i) = \frac{1}{v} \sum_{i=1}^{v} p_{(X|F)}(x|f_i)\, \delta(f - f_i), \quad (B1)$$

where δ is the Dirac delta function, and in the last step we assumed that F can take the values f_{1:v} with equal probability. Marginalizing over F, the marginal distribution of X reads:

$$p_X(x) = \frac{1}{v} \sum_{i=1}^{v} p_{(X|F)}(x|f_i). \quad (B2)$$

Combining the last two equations, we obtain:

$$I(X, F) = \frac{1}{v} \sum_{i=1}^{v} \int_{\mathcal{X}} p_{(X|F)}(x|f_i) \ln \frac{p_{(X|F)}(x|f_i)}{p_X(x)} \, \mathrm{d}x, \quad (B3)$$

as reported in Eq. (3).

Appendix C: Ground truth values of mutual information
We report the true values of MI for the bivariate distributions considered in Sect. III. These values can be obtained via direct integration of Eq. (1), and depend on the real-valued parameter α > 0. For the gamma-exponential distribution [54,56,99,100] as defined in Eq. (13):

$$I(X, Y) = \psi(\alpha + 1) - \ln \alpha, \quad (C1)$$

where ψ is the digamma function, defined as:

$$\psi(x) = \frac{\mathrm{d}}{\mathrm{d}x} \ln \Gamma(x). \quad (C2)$$

For the ordered Weinman exponential distribution [54,56,99,100] as defined in Eq. (14):

$$I(X, Y) = \ln\left(\frac{2\alpha - 1}{2\alpha}\right) + \frac{1}{2\alpha} + \psi\left(2 + \frac{1}{2\alpha - 1}\right) - \psi(2). \quad (C3)$$
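As an illustrative numerical cross-check of (C1), one can draw samples from Eq. (13) by composition (the marginal of X is Gamma(α), and Y | X = x is exponential with rate x) and compare a nearest-neighbor estimate on the log-transformed variables, as in Sect. III, against the analytic value; this sketch uses scikit-learn's KSG-style estimator as a stand-in:

```python
import numpy as np
from scipy.special import digamma
from sklearn.feature_selection import mutual_info_regression

alpha = 2.0
rng = np.random.default_rng(3)
x = rng.gamma(alpha, size=50_000)   # X ~ Gamma(alpha), marginal of Eq. (13)
y = rng.exponential(1.0 / x)        # Y | X = x ~ Exp(rate = x)

# Estimate on log-transformed variables (MI is invariant; see Sect. III).
mi_hat = mutual_info_regression(np.log(x)[:, None], np.log(y),
                                n_neighbors=3)[0]
print(f"estimated: {mi_hat:.3f} nat, "
      f"analytic (C1): {digamma(alpha + 1) - np.log(alpha):.3f} nat")
```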