Ruikun Li1, Jiazhen Liu2, Huandong Wang2*, Qingmin Liao1, Yong Li2
1 Shenzhen International Graduate School, Tsinghua University,
2 Department of Electronic Engineering, BNRist, Tsinghua University
* Corresponding Author (wanghuandong@tsinghua.edu.cn)
WeightFlow models the weights of a backbone neural network as a graph and employs a graph neural differential equation to learn the continuous dynamics of this weight graph. The framework consists of two main components:
1. Backbone ($\theta_t$): A backbone network with parameters $\theta_t$ models the static probability distribution at time $t$ via an autoregressive factorization over the state dimensions (see the sketch after this list).
2. Hypernetwork ($g_{\phi}$): A graph hypernetwork $g_{\phi}$ then models the continuous evolution of these weights $\theta_t$ as a Controlled Differential Equation (CDE), also sketched below.
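A minimal sketch of these two components in LaTeX. The notation here is ours and may differ from the paper's exact parameterization: $x^{<i}$ denotes the state dimensions preceding $x^i$, and $X_t$ is the control path driving the CDE.

$$
p_{\theta_t}(x) \;=\; \prod_{i=1}^{d} p_{\theta_t}\!\left(x^{i} \mid x^{<i}\right),
\qquad
\theta_t \;=\; \theta_{t_0} + \int_{t_0}^{t} g_{\phi}(\theta_s)\,\mathrm{d}X_s .
$$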
We empirically evaluate WeightFlow on a diverse set of simulated and real-world stochastic dynamics, demonstrating its superior performance and robustness.
We first benchmarked WeightFlow against several state-of-the-art baselines on five discrete stochastic systems. As shown in Table 1, WeightFlow significantly outperforms all baselines, improving the Wasserstein distance ($\mathcal{W}$) and Jensen-Shannon divergence (JSD) by 32.04% and 53.99% on average, respectively.
| Model | Epidemic $\mathcal{W}\downarrow$ | Epidemic $JSD\downarrow$ | Toggle Switch $\mathcal{W}\downarrow$ | Toggle Switch $JSD\downarrow$ | Signalling Cascade 1 $\mathcal{W}\downarrow$ | Signalling Cascade 1 $JSD\downarrow$ | Signalling Cascade 2 $\mathcal{W}\downarrow$ | Signalling Cascade 2 $JSD\downarrow$ | Ecological Evolution $\mathcal{W}\downarrow$ | Ecological Evolution $JSD\downarrow$ |
|---|---|---|---|---|---|---|---|---|---|---|
| Latent SDE | 3.14±0.25 | 4.22±0.26 | 2.34±0.15 | 1.27±0.12 | 3.04±0.17 | 0.85±0.14 | 3.59±0.13 | 1.02±0.06 | 8.04±0.33 | 3.52±0.23 |
| Neural MJP | 1.88±0.14 | 1.61±0.14 | 2.13±0.26 | 0.94±0.14 | 1.69±0.15 | 0.30±0.04 | 1.68±0.11 | 0.36±0.01 | 1.68±0.18 | 0.51±0.03 |
| T-IB | 2.62±0.17 | 3.52±0.29 | 1.59±0.20 | 0.88±0.11 | 1.66±0.16 | 0.32±0.04 | 2.16±0.17 | 0.40±0.03 | 2.17±0.24 | 0.56±0.06 |
| NLSB | 3.27±0.28 | 1.65±0.14 | 2.97±0.30 | 1.32±0.20 | 1.50±0.10 | 0.39±0.05 | 1.83±0.15 | 0.48±0.05 | 3.09±0.26 | 2.80±0.32 |
| DeepRUOT | 1.78±0.13 | 1.08±0.09 | 1.37±0.17 | 0.77±0.05 | 0.52±0.02 | 0.07±0.00 | 0.51±0.01 | 0.08±0.00 | 3.27±0.31 | 2.47±0.36 |
| WeightFlow (Ours) | 1.10±0.14 | 0.34±0.01 | 0.82±0.07 | 0.33±0.02 | 0.48±0.03 | 0.04±0.00 | 0.49±0.07 | 0.06±0.01 | 0.51±0.07 | 0.12±0.02 |
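For reference, a minimal sketch of how the two reported metrics can be computed from predicted and ground-truth samples of a single discrete dimension. The paper's exact evaluation protocol (binning, aggregation over dimensions and time steps) may differ, and `eval_marginal` is a hypothetical helper name:

```python
# Minimal sketch of the two reported metrics; the paper's exact evaluation
# protocol (binning, dimensions, aggregation over time) may differ.
import numpy as np
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon

def eval_marginal(pred_samples, true_samples, n_states):
    """Compare 1D empirical distributions of predicted vs. ground-truth states."""
    # Wasserstein-1 distance between the two empirical sample sets
    w = wasserstein_distance(pred_samples, true_samples)
    # Jensen-Shannon divergence between normalized histograms over the discrete states
    bins = np.arange(n_states + 1)
    p, _ = np.histogram(pred_samples, bins=bins, density=True)
    q, _ = np.histogram(true_samples, bins=bins, density=True)
    jsd = jensenshannon(p, q) ** 2  # jensenshannon returns the square root of JSD
    return w, jsd

# Example usage with toy samples over 10 discrete states
rng = np.random.default_rng(0)
pred = rng.integers(0, 10, size=5000)
true = rng.integers(0, 10, size=5000)
print(eval_marginal(pred, true, n_states=10))
```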
In the ecological evolution system (visualized below), a 2D genetic phenotype (Locus 1, Locus 2) evolves towards a global peak on a fitness landscape. WeightFlow accurately predicts the distribution's evolution, capturing both macroscopic landscape shifts and fine-grained local dynamics.
We also evaluated WeightFlow on high-dimensional, continuous-space single-cell differentiation datasets. The visualization of the pancreatic β-cell differentiation path shows our model's predictions. WeightFlow is significantly more accurate on higher-order moments such as skewness and kurtosis, reproducing fine-grained structure of the distributions.
| Model | $\beta$-cell $\mathcal{W}\downarrow$ | $\beta$-cell $MMD\downarrow$ | Embryoid $\mathcal{W}\downarrow$ | Embryoid $MMD\downarrow$ |
|---|---|---|---|---|
| NLSB | 11.18±0.22 | 0.07±0.01 | 14.39±0.40 | 0.10±0.03 |
| RUOT | 10.99±0.20 | 0.06±0.01 | 14.71±0.49 | 0.15±0.03 |
| WeightFlow (Ours) | 9.73±0.27 | 0.02±0.01 | 14.18±0.43 | 0.03±0.01 |
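Similarly, a hedged sketch of the MMD metric (a generic biased RBF-kernel estimate; the kernel and bandwidth used in the paper are not specified here) together with the higher-order moments mentioned above. `mmd_rbf` and `moments` are illustrative names:

```python
# Generic (biased) RBF-kernel MMD estimate between two sample sets; the kernel
# choice and bandwidth used in the paper may differ.
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import skew, kurtosis

def mmd_rbf(x, y, sigma=1.0):
    """Biased MMD^2 estimate with a Gaussian kernel of bandwidth sigma."""
    kxx = np.exp(-cdist(x, x, "sqeuclidean") / (2 * sigma**2)).mean()
    kyy = np.exp(-cdist(y, y, "sqeuclidean") / (2 * sigma**2)).mean()
    kxy = np.exp(-cdist(x, y, "sqeuclidean") / (2 * sigma**2)).mean()
    return kxx + kyy - 2 * kxy

def moments(cells):
    """Per-feature skewness and kurtosis, e.g. to compare predicted vs. true cells."""
    return skew(cells, axis=0), kurtosis(cells, axis=0)

# Example usage on toy 50-dimensional "cells"
rng = np.random.default_rng(0)
x, y = rng.normal(size=(500, 50)), rng.normal(size=(500, 50))
print(mmd_rbf(x, y), moments(x)[0].shape)
```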
We performed ablation studies to validate key design choices of WeightFlow.
We analyzed WeightFlow's sensitivity to various hyperparameters, demonstrating its robustness.
WeightFlow is designed to be scalable and efficient. The backbone's size is independent of the system dimension $d$: it has $O(L)$ space complexity, where $L$ is the number of states per dimension, and $O(d)$ inference time. The hypernetwork's complexity, $O(N_{\text{nodes}}^2)$, is likewise independent of $d$. This design effectively avoids the curse of dimensionality.
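The scaling argument can be illustrated schematically. The sketch below is not the paper's architecture; it only shows why a shared per-dimension autoregressive module has a parameter count depending on $L$ and a fixed hidden size rather than on $d$, and why sampling takes $O(d)$ sequential steps:

```python
# Schematic illustration of the scaling argument (not the paper's architecture):
# a single per-dimension module is reused across all d dimensions, so its
# parameter count depends on L (states per dimension) and a fixed hidden size,
# not on d; sampling visits each of the d dimensions once, i.e. O(d) time.
import torch
import torch.nn as nn

class SharedAutoregressiveBackbone(nn.Module):
    def __init__(self, n_states: int, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(n_states + 1, hidden)  # +1 for a start token
        self.rnn = nn.GRUCell(hidden, hidden)            # shared across dimensions
        self.head = nn.Linear(hidden, n_states)          # logits over the L states

    @torch.no_grad()
    def sample(self, d: int, batch: int = 1):
        h = torch.zeros(batch, self.rnn.hidden_size)
        x_prev = torch.full((batch,), self.embed.num_embeddings - 1)  # start token
        samples = []
        for _ in range(d):                               # O(d) sequential passes
            h = self.rnn(self.embed(x_prev), h)
            x_prev = torch.distributions.Categorical(logits=self.head(h)).sample()
            samples.append(x_prev)
        return torch.stack(samples, dim=-1)

backbone = SharedAutoregressiveBackbone(n_states=10)
print(sum(p.numel() for p in backbone.parameters()))  # independent of d
print(backbone.sample(d=50, batch=4).shape)           # torch.Size([4, 50])
```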
If you find our work useful for your research, please consider citing: