CYJAX: A package for Calabi-Yau metrics with JAX

We present the first version of CYJAX, a package for machine learning Calabi-Yau metrics using JAX. It is meant to be accessible both as a top-level tool and as a library of modular functions. CYJAX is currently centered around the algebraic ansatz for the K\"ahler potential which automatically satisfies K\"ahlerity and compatibility on patch overlaps. As of now, this implementation is limited to varieties defined by a single defining equation on one complex projective space. We comment on some planned generalizations.


Introduction
Calabi-Yau (CY) manifolds appear ubiquitously in the context of string theory as extra-dimensional geometries, since a large fraction of consistent string theory constructions is based on such CY-manifolds. The geometry of these manifolds determines the resulting effective field theories (EFT). Unlike topological properties, such as the particle spectrum, which are accessible with existing mathematical tools, the metric itself is generally unknown for these spaces. This leaves a gap in connecting string theory with its low-energy observables, such as the couplings among these particles, which are determined by the metric.
In the absence of analytic solutions, several numerical approaches have been proposed to identify metrics of CY-manifolds. In short, there are currently two approaches which can be used to approximate CY metrics:
• Using a fixed-point iteration scheme, widely referred to as Donaldson's algorithm [1][2][3][4][5][6]. This spectral method was developed after earlier approaches using finite difference methods proved impractical [7].
• Using energy functional minimization, where the energy functional measures Ricci flatness or equivalent quantities, such as an appropriate Monge-Ampère loss [8][9][10][11][12][13]. At the moment, this approach is computationally the most efficient.
Obtaining either of these measures involves specialized code which has been developed for several proof-of-concept examples. To facilitate the development of these methods and their application to other manifolds, we are developing the package CYJAX, which aims at providing a modular environment for studying extra-dimensional metrics, in particular CY metrics. The modularity of the implementation makes it possible to extend and adapt the package for dedicated use cases. Our code is based on JAX [14], which enables significant acceleration over pure Python code via jit-compilation and automatic vectorization using vmap. It also allows programmatic generation of efficient code starting from symbolically derived expressions, e.g. using SYMPY. This article is a short note on the first release of CYJAX. The basis of this code was developed in large parts for the Master's thesis of one of us [15] and has been used for some experiments in [9]. This package is complementary to the package CYMETRIC [12], making different code design choices and using a different ansatz for the metric. Instead of directly parametrizing it by a neural network, we use an algebraic ansatz which is guaranteed to be Kähler and which always matches on patch overlaps. Compared to CYMETRIC, we can thus use two fewer losses. Further, we focus immediately on learning moduli-dependent approximations of the Calabi-Yau metric, instead of optimizing for single points in moduli space. Beyond this, we hope that independent implementations of several algorithmic aspects may be useful to the community. In future work, we expect direct quantitative comparisons to be straightforward, as JAX (used by us) and Tensorflow/NumPy (used in CYMETRIC) are interoperable to a large extent. This may also allow combining aspects of each implementation for a particular task. Importantly, our package provides an independent determination of the metric, which can be useful in comparing performance.
The rest of this note is organized as follows. Section 2 introduces the manifolds, our choices for relevant quantities to study metrics, and comments on their respective code implementation. Section 3 showcases a few examples of how CYJAX can be used. In Section 4 we provide an overview of some of the planned functionalities for future releases.

Choices for CY Manifolds
We consider here varieties X ⊂ P^{d+1} which are defined as the zero locus of a single homogeneous polynomial:

    Q(z) = \sum_\alpha \psi_\alpha z^\alpha = 0 , \qquad z^\alpha = \prod_{i=0}^{d+1} z_i^{\alpha_i} .

Recall that the homogeneous coordinates z = [z_0 : ... : z_{d+1}] on projective space P^{d+1} are just the complex coordinates of C^{d+2} with the identification z ∼ λz for all λ ∈ C \ {0}. The multi-index α ranges over all natural numbers such that \sum_{i=0}^{d+1} \alpha_i = d+2, corresponding to all monomials of degree d+2. Choosing particular values of the coefficients ψ_α ∈ C corresponds to choosing the complex structure moduli. The dimension of the variety X in the above example is d. Of particular interest in physics is the case d = 3, where X is called a quintic threefold (referring to the degree of the defining equation and the manifold's dimension, respectively). We may restrict which coefficients ψ we allow to be nonzero. A particular example is the so-called Dwork family of quintics, which has a single complex parameter ψ:

    Q(z) = \sum_{i=0}^{4} z_i^5 - 5 \psi \prod_{i=0}^{4} z_i .

Note that in this case we have dropped the index on ψ, as there is only one free component, for α = (1, ..., 1). One can show that the first Chern class vanishes for the above manifolds (see e.g. [18]). From the work by Yau [19], we thus know that a Ricci-flat metric exists, and we want to find a numerical approximation to it. The present class of examples has a single Kähler modulus, h^{11} = 1, so the metric's dependence on this modulus is absorbed by an appropriate rescaling. A generalization to the case with multiple Kähler moduli is left for the future.
In CYJAX, varieties of the above type are represented by the VarietySingle class and are essentially characterized by their defining polynomial. There are multiple ways to specify this polynomial in code, the most convenient of which is to pass it as a SYMPY-style string expression via VarietySingle.from_sympy. For convenience, the Dwork and Fermat quintics are also readily implemented (Fermat, Dwork).

Choice of coordinates
To study the metric on these varieties, we have to identify suitable coordinates. Below, we describe how we relate the coordinates on the ambient space to coordinates on the variety. There is no unique choice for going from the d + 2 coordinates in the ambient complex space to proper coordinates on the d-dimensional variety X. One option is to accept the redundancy and use the full set of homogeneous coordinates. However, some geometric quantities have no globally defined numerical representation, and thus a choice of local coordinate patches has to be made. The description below also serves as an overview of the conventions we choose in CYJAX.

Coordinates in ambient space
To remove the scaling ambiguity in homogeneous coordinates, we pick one index (with non-zero entry) and set its value to 1 by rescaling:

    [z_0 : ... : z_p : ... : z_{d+1}] → [z_0/z_p : ... : 1 : ... : z_{d+1}/z_p] .

If the coordinate with patch index p is scaled to one and omitted, we denote the remaining d + 1 affine coordinates by z^{(p)}. Computationally, the affine coordinates are represented by an array with d + 1 entries together with an integer specifying the patch they are in (i.e. which homogeneous index was scaled to 1). Going from homogeneous coordinates to patch p, we divide by the value of z_p. Numerically, it is advantageous to avoid very large values. For the numerically "optimal" patch we thus, by default, choose p such that |z_p| is maximal. We still have two choices of how to store local coordinates:
1. We can keep the full array of homogeneous coordinates but always scale the largest value to 1. This allows each coordinate to be represented uniquely by a single array. However, since the patch index is implicit, we could then not force a function to treat numerical inputs as lying in another patch.
2. Alternatively, we can keep an array of affine coordinates together with a patch index. This has the benefit of saving slightly in computational cost, especially where the patch index has to be known explicitly, at the memory cost of carrying around two arrays.
Most functions implemented in CYJAX can be called both with homogeneous and with local coordinates on projective space. The cases are distinguished by whether or not a patch index is supplied. When indexing into affine or homogeneous coordinate arrays one has to be somewhat careful, as removing the 'redundant' element with value 1 shifts the indices. The index k with respect to the affine coordinates z^{(p)}_k will here be referred to as the affine index.
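As a minimal sketch of this bookkeeping (plain NumPy, not the CYJAX API; the helper names to_optimal_patch and hom_index are hypothetical illustrations):

```python
import numpy as np

def to_optimal_patch(z_hom):
    """Map homogeneous coordinates to affine coordinates in the numerically
    'optimal' patch: divide by the entry of largest magnitude and drop it."""
    p = int(np.argmax(np.abs(z_hom)))          # patch index: |z_p| maximal
    z_affine = np.delete(z_hom / z_hom[p], p)  # scale z_p to 1 and omit it
    return z_affine, p

def hom_index(affine_index, patch):
    """Convert an affine index k (into z^(p)) back to a homogeneous index,
    accounting for the entry omitted at position p."""
    return affine_index if affine_index < patch else affine_index + 1
```

The index shift in hom_index is exactly the caveat mentioned above: removing the redundant entry with value 1 shifts all subsequent indices by one.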

Local coordinates on the variety
Going from affine coordinates to coordinates on the variety X itself, we must eliminate one additional redundant entry, since one coordinate value is determined by solving the defining equation. The simplest way to pick coordinates is to choose one coordinate index that is considered "dependent" on the other values. All other coordinate entries are kept.
As a particular example, consider a (d + 1) × (d + 1) matrix g denoting a metric in the ambient projective space. We now want to compute the pullback of this metric to the variety,

    g^X_{i \bar j} = \frac{\partial z^{(p)}_a}{\partial x_i} \, g_{a \bar b} \, \frac{\partial \bar z^{(p)}_b}{\partial \bar x_j} ,

where i and \bar j only range over the d values corresponding to the choice of dependent entry, and x denotes the independent coordinates. The induced metric can be calculated via induced_metric. If z^{(p)}_m, i.e. index m, is chosen as the dependent variable, only ∂z^{(p)}_m/∂x_i is non-trivial (all others being either one or zero). The Jacobians can be computed directly (and automatically) from the defining equation via the implicit function theorem,

    \frac{\partial z^{(p)}_m}{\partial x_i} = - \frac{\partial Q / \partial z^{(p)}_i}{\partial Q / \partial z^{(p)}_m} .

This is implemented in the function jacobian_embed. On the variety, we work with the "optimal" dependent coordinate, which minimizes the entries of the Jacobian.
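The automatic computation of these Jacobians can be sketched with JAX's holomorphic differentiation. This is an illustration for one affine patch of the Fermat quintic, not the jacobian_embed implementation, and the function names here are hypothetical:

```python
import jax
import jax.numpy as jnp

# Fermat quintic in one affine patch: Q(z) = 1 + sum_i z_i^5 (the omitted
# homogeneous coordinate, scaled to 1, contributes the constant term).
def Q(z):
    return 1 + jnp.sum(z**5)

# dQ/dz_a computed automatically; Q is holomorphic in z.
grad_Q = jax.grad(Q, holomorphic=True)

def dependent_jacobian(z, m):
    """dz_m/dx_i from the implicit function theorem applied to Q(z) = 0:
    dz_m/dx_i = -(dQ/dz_i)/(dQ/dz_m) for the independent indices i != m."""
    g = grad_Q(z)
    jac = -g / g[m]
    return jnp.delete(jac, m)   # keep only the independent directions
```

Choosing m such that |dQ/dz_m| is maximal keeps the entries of this Jacobian small, which is the "optimal" dependent coordinate mentioned above.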

Sampling points
To numerically evaluate integrals, e.g. to compute the volume, we use a Monte Carlo approximation. We thus need to generate samples that lie on the manifold and have a known distribution. One straightforward way of generating points on the d-dimensional manifold is to sample d random complex numbers and solve the defining equation for the remaining coordinate value. However, we do not a priori know the distribution of these points. Instead, we can sample points as intersections of the variety with a line in ambient projective space. The distribution of these points is known [3].

Complex and projective coordinates
If complex values are generated by sampling real and imaginary parts uniformly, points will lie on a square in the complex plane. Uniformly sampling the radius in [0, 1] and the complex angle in [0, 2π] ensures that points lie on the unit disk, but the resulting distribution does not have uniform density. Instead, one should sample uniformly from the disk in R^2 and interpret the two coordinates as real and imaginary parts. In order to generate uniformly distributed samples on P^n, we can first sample uniform points on the real sphere S^{2n+1}, represented by 2n + 2 real numbers. By pairing these up into n + 1 complex values we obtain homogeneous coordinates on the projective space. This construction corresponds to P^n ≅ S^{2n+1}/U(1). Generating uniform points on the real sphere S^{2n+1} can be done efficiently by independently sampling 2n + 2 real numbers from the unit-covariance normal distribution and dividing by their vector norm. Since the normal distribution factorizes as

    p(z) ∝ \exp\left( -\tfrac{1}{2} \sum_i z_i^2 \right) = \prod_i \exp\left( -\tfrac{1}{2} z_i^2 \right) ,

we can efficiently draw from the joint distribution by sampling each component z_i independently. The exponent \sum_i z_i^2 is manifestly invariant under SO(2n + 2) rotations, as desired for a uniform distribution, but the norm |z| is not yet fixed to 1. Dividing the sampled points by their norm puts them on the unit sphere while preserving the SO(2n + 2) symmetry, which therefore gives us a uniform distribution on S^{2n+1}.
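The construction just described fits in a few lines of NumPy (an illustrative sketch, not the CYJAX implementation):

```python
import numpy as np

def sample_projective(n, num_samples, rng=None):
    """Draw points uniformly on P^n (w.r.t. the Fubini-Study volume):
    sample 2n+2 standard normals, pair them into n+1 complex numbers,
    and normalize onto the unit sphere S^(2n+1)."""
    rng = rng or np.random.default_rng()
    x = rng.standard_normal((num_samples, n + 1))
    y = rng.standard_normal((num_samples, n + 1))
    z = x + 1j * y                                    # points in C^(n+1)
    return z / np.linalg.norm(z, axis=1, keepdims=True)
```

The returned rows are homogeneous coordinates; the residual U(1) phase is exactly the projective identification and drops out of any well-defined quantity.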

Points on varieties
To numerically estimate integrals, we make use of the Monte Carlo approximation

    \int_X f \, \mathrm{dvol} = \int_X f \, \frac{\mathrm{dvol}}{\mathrm{dA}} \, \mathrm{dA} \approx \frac{1}{N} \sum_{a=1}^{N} f(x_a) \, w(x_a) .

Here, the x_a are drawn using some (pseudo-)probabilistic procedure with known density measure dA, which is explicitly determined using the determinant of the induced Fubini-Study metric on the variety, i.e. dA = det(ι* g_FS). The volume form dvol is evaluated via the holomorphic top form, which is known a priori [20]. The weights w = dvol/dA in the final step are required to correct for the difference in the measures. The sampling method we use here is described in [3]. After uniformly sampling two points p, q ∈ P^{d+1}, we can define a line p + tq with t ∈ C. Samples on the variety are then given by the intersection of this line with the variety, i.e. by solutions for t such that Q(p + tq) = 0. The density of samples generated in this way is known and given in terms of the Fubini-Study metric.
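For the Fermat quintic, the intersection condition Q(p + tq) = 0 is a degree-5 polynomial in t, so the roots can be found directly. The following sketch (plain NumPy, not the sample_intersect implementation; the function name is hypothetical) expands each (p_i + t q_i)^5 binomially and collects coefficients:

```python
import numpy as np
from math import comb

def intersect_fermat_quintic(p, q):
    """Solve Q(p + t q) = 0 for the Fermat quintic Q(z) = sum_i z_i^5.
    Expanding (p_i + t q_i)^5 gives the coefficient of t^k as
    C(5, k) p_i^(5-k) q_i^k, so Q(p + t q) is a degree-5 polynomial in t
    with generically 5 roots, i.e. 5 intersection points per line."""
    coeffs = [sum(comb(5, k) * p_i**(5 - k) * q_i**k
                  for p_i, q_i in zip(p, q))
              for k in range(6)]
    return np.roots(coeffs[::-1])   # np.roots expects highest degree first
```

Each root t gives a sample point p + tq on the variety; the weights w then follow from the known density of this scheme.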
In our implementation, the variety object contains methods which allow the calculation of samples using this intersection method (sample_intersect) and their weights (sample_intersect_weights). The simplest associated application is to compute the volume via the Monte Carlo approximation (compute_vol).

Algebraic metrics
The primary quantity we try to learn is the Hermitian matrix H in the following algebraic ansatz for the Kähler potential, taken from Donaldson's algorithm:

    K(z, \bar z) = \frac{1}{k \pi} \log \sum_{\alpha \bar\beta} H^{\alpha \bar\beta} \, s_\alpha(z) \, \bar s_{\bar\beta}(\bar z) .

Here, the s_α represent a set of homogeneous polynomials of some chosen degree k, which correspond to sections of a line bundle. Geometrically, one can understand a set of N_k sections s_α as defining an embedding of the variety into the complex projective space of dimension N_k − 1. The resulting Kähler metric on the variety can then be understood as the pullback of the higher-dimensional Fubini-Study metric, generalized by the Hermitian matrix H [1]. The larger the polynomial degree k, the larger the complete set of sections N_k, and thus the higher the potential resolution of the algebraic ansatz. For Donaldson's algorithm, there is the additional requirement that the set has to form a basis of line bundle sections on the variety, because the algorithm involves a matrix inversion which is otherwise ill-defined [1]. Any linear combination proportional to the defining polynomial vanishes, and thus the full set of homogeneous monomials, for example, does not always form a basis on the variety. For the machine learning application, however, there is no such requirement, and we can use the full set of monomials of some chosen degree in ambient projective space, which merely amounts to an over-parametrization. Note also that one could use any other set of linear combinations of monomials for the sections. While both of these choices can be absorbed into reparametrizations of H, in principle they may influence the particular training behavior and numerics. The implementation in CYJAX explicitly exposes the choice of sections, and custom versions can be added. For convenience, two sets are implemented and can be used directly: MonomialBasisFull and MonomialBasisReduced. Finally, one needs to choose suitable initial values for the Hermitian matrix H, for which we present examples in Section 3.
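A minimal sketch of evaluating this ansatz, assuming the normalization K = log(s H s̄)/(kπ) (conventions may differ by constant factors); this is illustrative NumPy, not the CYJAX implementation. For k = 1 the sections are just the coordinates themselves, and H equal to the identity reproduces the Fubini-Study potential:

```python
import numpy as np

def algebraic_kahler_potential(z, sections, H, k):
    """K(z) = log( sum_{a,b} s_a(z) H^{ab} conj(s_b(z)) ) / (k*pi).
    `sections` maps a coordinate array to the vector of section values."""
    s = sections(z)
    return np.log(np.real(s @ H @ np.conj(s))) / (k * np.pi)

# Degree-1 check: sections s_a = z_a with H = identity give the
# Fubini-Study potential K = log(|z|^2) / pi.
```

Larger k enlarges the section basis (all degree-k monomials) and hence the resolution, at the cost of a larger matrix H to learn.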
A generalization to other metric ansätze is left for future releases.

Geometric quantities
Internally, the computation of geometric objects is contained in a computational graph. This avoids duplicate code and allows easier testing of intermediate values.
Given the choice of variety, the choice of Kähler potential, and the ability to sample points, one is now ready to calculate various properties. Readily implemented are the calculation of: the ratio of both top-forms eta, the associated σ-accuracy measure eta_accuracy, the Ricci curvature in local coordinates ricci, and the Ricci scalar ricci_scalar.
Internally, some of these quantities contain intermediate steps which can also be accessed. A more detailed overview can be found in Figure 1 and in the documentation. Several quantities are implemented explicitly, but can also be computed using (repeated) automatic differentiation. This provides a correctness check for our implementations. Comparing the timing (cf. the documentation for more details), we find that the explicit implementations are slightly faster.

Machine Learning
The tools discussed in the previous sections can be used to implement multiple machine learning approaches to approximating the Calabi-Yau metric in any framework that works with JAX. Functions for a particular approach using Flax [21] have been implemented and are accessible in the cyjax.ml submodule. This includes functions for initializing and working with the Cholesky decomposition of the Hermitian matrix H and a batched sampler class (cholesky_decode, cholesky_from_param, hermitian_param_init, BatchSampler).
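The idea behind the Cholesky parametrization is to map an unconstrained real parameter vector to a Hermitian positive-definite H. The sketch below is an illustrative stand-in for what cholesky_decode does, not the cyjax.ml implementation (the helper name decode_cholesky is hypothetical):

```python
import numpy as np

def decode_cholesky(params, n):
    """Decode n^2 real parameters into a Hermitian positive-definite
    matrix H = L L^dagger, with L lower triangular: n real parameters for
    a positive diagonal (via exp) plus n(n-1)/2 complex off-diagonals."""
    L = np.zeros((n, n), dtype=complex)
    diag, off = params[:n], params[n:]
    L[np.diag_indices(n)] = np.exp(diag)          # positive diagonal
    rows, cols = np.tril_indices(n, k=-1)
    m = len(rows)
    L[rows, cols] = off[:m] + 1j * off[m:]
    return L @ L.conj().T
```

Because any real parameter vector maps to a valid H, gradient-based optimizers can update the parameters freely without leaving the space of admissible matrices.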
Regarding sampling, one present constraint is that non-Hermitian eigenvalue computation is not yet implemented on the GPU. That means samples have to be generated on the CPU. We can thus either train and sample on the CPU, or generate samples on the CPU and transfer them to the GPU for training. Both options are available with CYJAX. The batched sampler mentioned above by default generates samples on the CPU and afterwards transfers them to the GPU, if available. This allows for efficient on-the-fly sample generation without repetition, which means that overfitting, a problem generally present in supervised learning, can be avoided. There are multiple different losses which effectively measure the Ricci-flatness of the approximated metric. In particular, we use here the so-called σ-accuracy and a Monge-Ampère loss L_MA, which rely on the property that the Ricci-flat metric g gives rise to a volume form which must be proportional to the one given by the holomorphic top form [8].
If Ω is the holomorphic top form, we can define the ratio

    η = \frac{\det g}{\Omega \wedge \bar\Omega} .

The σ-accuracy measures the deviation of η from being constant as the integral

    σ = \frac{1}{\mathrm{Vol}_\Omega} \int_X \left| \frac{\eta}{\langle \eta \rangle} - 1 \right| \mathrm{dvol}_\Omega ,

where ⟨η⟩ denotes the average of η over the manifold. For training, we use the related variance-like Monge-Ampère loss

    L_{MA} = \int_X \left( \frac{\eta}{\langle \eta \rangle} - 1 \right)^2 \mathrm{dvol}_\Omega ,

which approximates the integral with respect to the volume form dvol_Ω using Monte Carlo weights w(z). These weights "undo" the bias introduced by the sampling scheme used to sample points z on the manifold, as discussed in Section 2.2. Lastly, there is a configurable MLP-like network for learning the moduli dependence of the H matrix, HNetMLP. For this network one can configure the features, e.g. powers of the moduli, which are used as input to the fully connected neural network. Several standard hyperparameters, e.g. the layer sizes and activation functions, can be chosen. It is beneficial to suppress almost-vanishing components by multiplying them with a learnable sigmoid factor, i.e. α · sigmoid(β). A schematic overview of the H network can be found in Figure 2. These networks can easily be configured as part of the configuration files for the included machine learning script.
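Given samples of η and their Monte Carlo weights, both quantities reduce to weighted averages. A minimal sketch (illustrative NumPy, not the CYJAX implementation; eta_stats is a hypothetical helper name):

```python
import numpy as np

def eta_stats(eta, weights):
    """Weighted Monte Carlo estimates of the sigma-accuracy and of a
    variance-like Monge-Ampere loss from samples of the ratio eta.
    The weights w = dvol/dA undo the bias of the sampling scheme."""
    w = weights / np.sum(weights)        # normalize the measure
    eta_mean = np.sum(w * eta)           # <eta> over the manifold
    sigma = np.sum(w * np.abs(eta / eta_mean - 1))
    loss_ma = np.sum(w * (eta / eta_mean - 1) ** 2)
    return sigma, loss_ma
```

Both vanish exactly when η is constant over the samples, i.e. when the volume forms are proportional as required for Ricci flatness.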
In summary, training of the network parameters is done using the following schematic procedure, which averages the loss over several moduli values in each step:
1. Sample N_ψ moduli values ψ^(i) in the desired moduli range.
2. Apply the neural network with the current parameters θ to predict the matrices H_θ(ψ^(i)) for the selected moduli.
3. For each ψ^(i), sample N_z points {z^(i,j)} on the variety as outlined in Section 2.2, together with their Monte Carlo weights w(z^(i,j)).
4. Following the computational graph of Figure 1, compute the ratio η for each of the points and their respective moduli. Note that this ratio depends on the current approximation to the metric parametrized by H, and thus introduces the dependence on the model parameters θ.
5. Evaluate the loss of equation (10), averaged over the N_ψ moduli values.
6. Compute the gradients of L with respect to the network parameters θ via automatic differentiation. Use these to update the network parameters by gradient descent.
7. Repeat by returning to the first step, until some stopping criterion is met (e.g. a fixed time elapsed, target loss reached, etc.).

Figure 2: The H network predicts the matrix H of equation (8) given the moduli ψ as input. The hyperparameters include fixed powers p that the moduli are raised to in order to construct features which are passed to dense linear layers.
Note that H parametrizes the metric over the whole variety, and thus the network itself does not take points where the metric is evaluated as input. A detailed application of these steps to the Dwork quintic can be found in section 3.1.
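The loop structure of this procedure can be sketched in plain JAX. Here `model` is a toy stand-in for the H network and the squared-error loss replaces the η-based Monge-Ampère loss; neither is the CYJAX implementation:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the H network: two real parameters, input |psi|.
def model(theta, psi):
    return theta[0] + theta[1] * jnp.abs(psi)

def loss(theta, psi_batch, target):
    # Average the per-modulus loss over the batch of moduli values.
    preds = jax.vmap(lambda p: model(theta, p))(psi_batch)
    return jnp.mean((preds - target(psi_batch)) ** 2)

def train_step(theta, psi_batch, target, lr=0.2):
    grads = jax.grad(loss)(theta, psi_batch, target)  # step 6: autodiff
    return theta - lr * grads                         # gradient descent

key = jax.random.PRNGKey(0)
theta = jnp.zeros(2)
target = lambda psi: 1.0 + 2.0 * jnp.abs(psi)   # synthetic training target
for _ in range(300):
    key, k1, k2 = jax.random.split(key, 3)
    # Step 1: sample a batch of moduli values in the desired range.
    psi_batch = jax.random.uniform(k1, (8,)) + 1j * jax.random.uniform(k2, (8,))
    theta = train_step(theta, psi_batch, target)
```

In the actual setup, the per-modulus loss additionally involves sampling points on the variety and evaluating η through the computational graph, but the moduli-batched gradient step has the same shape.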

Applications
Finally, we showcase a few applications of CYJAX. The scripts associated with these experiments can be found in our code repository, as well as notebooks explaining them and the procedures mentioned above.

Moduli dependent machine learning
We discuss two examples of moduli dependent machine learning. Readers primarily interested in the results may skip over some of the more detailed code examples.

Dwork quintic with different basis resolutions
To illustrate the explicit use of our package, we showcase here the steps to set up a neural network which learns to approximate the metric for the Dwork quintic. Note that throughout, we have to explicitly manage a random key. Given this state variable, pseudo-random number generation is deterministic. In order to force careful consideration of this, and to facilitate reliable randomness in parallel computations, JAX requires explicit management of the random key. At the beginning of a program, it is initialized by choosing some random seed, e.g. depending on the system time or set by hand. The key can then be split indefinitely; once a key has been used to generate some random values, it must be discarded. First, we set up the problem by choosing the parametrized family of varieties and a monomial basis with respect to which we try to learn H. The aim of our network is to learn a map ψ → H such that the corresponding algebraic metric is close to Ricci flat. For illustration, we show a simple network which only depends on the absolute value of ψ.
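The key-handling pattern looks as follows (plain JAX, independent of CYJAX):

```python
import jax

# Initialize the key from a seed (here set by hand).
key = jax.random.PRNGKey(42)

# Split before each use; a key that has generated values is discarded.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (3,))

key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, (3,))
```

Running the program again with the same seed reproduces the exact same random values, while the two subkeys yield independent draws.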
Next, we define the neural network as a Flax module. We will optimize its parameters to approximate the dependence of H on the complex moduli. With the neural network defined, we can now instantiate it according to the setup we have chosen above. This is followed by initializing the model parameters and using them to evaluate the H matrix. Let us now turn to the loss, which we include here to highlight the flexibility to change the optimization objective. For convenience, sampling is integrated into the loss function. First, we define a loss for a given fixed moduli value. Then we evaluate this loss over a batch of multiple moduli values, which gives the loss we use per training step. Figure 3 shows the results of training for the Dwork family of equation (3), using a network architecture as in Figure 2 with two hidden layers of size 4096, feature powers p = [1, 2, 3, 4, 5], and for degrees k = 4, 5, 6. Training was done for values in the square given by |Re[ψ]|, |Im[ψ]| < 10. As expected, we find improved accuracies for higher resolutions. The loss shown during training is a running average over 100 training steps. Each gradient descent step consists of 4 uniformly drawn samples of ψ and, for each, 500 points on the variety. The Adam optimizer with an exponentially decaying learning rate starting at 10^-3 was used. The σ-accuracy evaluated during training is the average over 10 initially drawn and fixed moduli values. All σ-accuracies were evaluated using 1000 independently drawn points on the variety. The edges of the final accuracy plots show values marginalized to only one real component by taking the mean, the maximum, and the minimum over the respectively other real component.

A two-moduli example
CYJAX presently allows for the study of single defining equations with multiple moduli, in which case the input to the neural networks includes multiple moduli. To illustrate this, we perform training for two moduli parameters with the training routine essentially as above. We consider the quintic defined in equation (12), which depends on two moduli parameters. The moduli parameters are again uniformly sampled with real and imaginary parts ranging from −10 to 10. The training behavior and the final accuracy are summarized in Figure 4 for degree k = 4. This example demonstrates that training still converges, but the accuracy appears worse than in the one-modulus case. This is not surprising, as the problem becomes more difficult the more moduli parameters we consider. A systematic analysis and a more elaborate hyperparameter scan for improved networks is beyond the scope of these notes.

Comparison with Donaldson's algorithm for the Dwork quintic
When evaluating neural network metrics, it is a good sanity check to compare with other methods for obtaining metrics. One such method is to construct metrics using Donaldson's algorithm. The CYJAX package implements methods to compute these, and a notebook showing how to run the algorithm can be found in the documentation. Here we show an example of results obtained by evaluating it for the Dwork quintic at different points in moduli space at k = 6. The results are shown in Figure 5. We observe that the qualitative behavior of the accuracies resembles the ones obtained with our neural networks as shown in Figure 3. Averaged over the selected range of moduli values, the σ-accuracy measure achieved with machine learning is about 4.4 times smaller (i.e. better) than the one obtained with Donaldson's algorithm. Note, however, that it is well known that Donaldson's algorithm does not yield the optimal metric at each fixed degree k [8].

Outlook
As described in the beginning, these notes are intended to describe only the initial release of CYJAX. There are many ways in which we intend it to be extended. Below are some possible directions:
• One natural extension is to include other varieties, where complete intersection Calabi-Yau manifolds [20,22] and the Kreuzer-Skarke list [23] are natural candidates. As more manifolds become available, it will be natural to connect with other packages focusing on different aspects, such as CYTOOLS [24].
• Another natural direction is to incorporate more general networks such as learning the metric directly or using different ansätze for the Kähler potential.
• Many components of CYJAX can be reused for the implementation of more general metrics, such as metrics with SU(3) structure as demonstrated in [9].
• We hope that the inherently auto-differentiable structure of the metric enables the investigation of other geometric objects which appear in phenomenologically relevant properties. For instance, our metric is differentiable with respect to the moduli, which is inherently necessary to link with studies of moduli stabilization.

Figure 4: Training behavior and obtained accuracy for our two-moduli example of the quintic as defined in Equation (12). Training hyperparameters are chosen as in Figure 3, except with two input moduli and powers from 1 to 6, as well as 10 moduli values per training batch. The marginalization for the heatmaps was done by taking the average over respectively all other real moduli parameters.
Apart from these methodological extensions, there is a natural quest for achieving networks which result in high-accuracy metrics. As illustrated in our two moduli example (cf. Section 3.1.2), there is room to improve upon our current illustrative networks to obtain better approximations to the metric. We hope to return to many of these aspects in the not too distant future.