I just shipped a v1.0 rewrite. It adds new curvature operators (Generalized Gauss-Newton, empirical Fisher) and new algorithms (Hutchinson and Hutch++ trace estimation, spectral density estimation via Stochastic Lanczos Quadrature). It also includes a fused Triton/torch.compile cross-entropy Hessian-vector-product kernel for foundation-model-scale vocabularies, where standard implementations blow up. More importantly, it adds a lot of numerical validation of the operators: closed-form correctness checks on linear and logistic regression, where the Hessian is known analytically, and cross-library tests against curvlinops to catch regressions.
https://github.com/noahgolmant/pytorch-hessian-eigenthings
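For anyone unfamiliar with the trace estimators mentioned above, here is a minimal NumPy sketch of Hutchinson and Hutch++ using only Hessian-vector products. It is not the library's API, and every function name here is hypothetical. It uses linear regression with loss (1/n)·||Xw - y||^2, whose Hessian (2/n)·X^T X is known in closed form, so the estimates can be checked against the exact trace, in the same spirit as the closed-form tests.

```python
import numpy as np

def linreg_hvp(X, v):
    # Hessian-vector product for L(w) = (1/n) * ||X w - y||^2.
    # H = (2/n) X^T X does not depend on w or y, so H v costs two
    # matrix-vector products and H is never formed explicitly.
    n = X.shape[0]
    return (2.0 / n) * (X.T @ (X @ v))

def hutchinson_trace(matvec, dim, num_matvecs, rng):
    # Hutchinson: E[v^T H v] = tr(H) for Rademacher v (entries +/-1).
    total = 0.0
    for _ in range(num_matvecs):
        v = rng.choice([-1.0, 1.0], size=dim)
        total += v @ matvec(v)
    return total / num_matvecs

def hutchpp_trace(matvec, dim, num_matvecs, rng):
    # Hutch++: use a third of the budget to sketch H's dominant subspace,
    # take the trace of that part exactly, and run Hutchinson on the rest.
    k = num_matvecs // 3
    S = rng.choice([-1.0, 1.0], size=(dim, k))
    G = rng.choice([-1.0, 1.0], size=(dim, k))
    HS = np.column_stack([matvec(S[:, j]) for j in range(k)])
    Q, _ = np.linalg.qr(HS)  # orthonormal basis for range(H S)
    HQ = np.column_stack([matvec(Q[:, j]) for j in range(Q.shape[1])])
    low_rank_trace = np.trace(Q.T @ HQ)
    G_perp = G - Q @ (Q.T @ G)  # project G onto the complement of span(Q)
    HG_perp = np.column_stack([matvec(G_perp[:, j]) for j in range(k)])
    residual_trace = np.trace(G_perp.T @ HG_perp) / k
    return low_rank_trace + residual_trace

rng = np.random.default_rng(0)
n, dim = 500, 50
# Geometrically decaying column scales give H a fast-decaying spectrum.
X = rng.standard_normal((n, dim)) * (0.8 ** np.arange(dim))
exact = np.trace((2.0 / n) * X.T @ X)
matvec = lambda v: linreg_hvp(X, v)
print("exact     ", exact)
print("Hutchinson", hutchinson_trace(matvec, dim, 60, rng))
print("Hutch++   ", hutchpp_trace(matvec, dim, 60, rng))
```

Both estimators get the same budget of 60 matvecs. With a fast-decaying spectrum like this one, Hutch++ typically does much better because the sketch captures most of the trace exactly. On a flat spectrum the advantage shrinks.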
I'm hoping to use it for some follow-up analysis. For example, right now I'm looking at agreement between the updates of various optimizers (Muon, K-FAC, natural gradient descent) on Pythia checkpoints.
Very open to suggestions or requests from anyone who's been working in this space. I've been out of the field for a while, so pointers to recent work I should be aware of are very welcome.
Is there a similar effect, where transferring the Hessian speeds up knowledge distillation even further?
Suppose one has a candidate alternative model architecture. How can one estimate the amount of compute needed for knowledge distillation into a student model?
Consider, for example, the following model: each token (or character, or bit) corresponds to a matrix (or a multivector), and a sequence of tokens corresponds to the matrix product (geometric product) of its tokens, taken in the same order. The relative (unnormalized) likelihood of a sequence is taken as exp(-||Product(M_i)||), where ||·|| of a matrix or multivector is its positive-definite squared norm (basically the sum of the squares of its components).
To get P(nextToken | productOfPreviousTokens) you calculate: P(productOfPreviousTokens * nextToken)/P(productOfPreviousTokens)
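A minimal NumPy sketch of the matrix version of this model may make it concrete. It does not cover the multivector/geometric-product case, and all names and the random token matrices are hypothetical. One assumption to flag: since exp(-||·||) is only a relative likelihood, the raw ratio P(prefix * next) / P(prefix) need not sum to 1 over next tokens. Normalizing that ratio over the vocabulary makes P(prefix) cancel, which leaves a softmax over -||Prefix · M_next||.

```python
import numpy as np

def prefix_product(token_mats, seq):
    # Ordered product M_{s1} @ M_{s2} @ ... (identity for an empty sequence).
    prod = np.eye(token_mats.shape[1])
    for t in seq:
        prod = prod @ token_mats[t]
    return prod

def relative_likelihood(token_mats, seq):
    # Unnormalized weight exp(-||Product(M_i)||), with ||.|| taken as the
    # squared Frobenius norm (sum of squared entries), as in the post.
    prod = prefix_product(token_mats, seq)
    return float(np.exp(-np.sum(prod ** 2)))

def next_token_probs(token_mats, prefix):
    # Energies ||P @ M_t||^2 for every vocabulary token t at once:
    # P has shape (d, d), token_mats (V, d, d), so P @ token_mats is (V, d, d).
    P = prefix_product(token_mats, prefix)
    energies = np.sum((P @ token_mats) ** 2, axis=(1, 2))
    # Softmax of -energy, equal to the ratio P(prefix*next) / P(prefix)
    # normalized over next tokens. Subtract the max for numerical stability.
    logits = -energies
    logits -= logits.max()
    probs = np.exp(logits)
    return probs / probs.sum()

rng = np.random.default_rng(0)
vocab, d = 5, 3
token_mats = 0.5 * rng.standard_normal((vocab, d, d))  # hypothetical token matrices
print(next_token_probs(token_mats, [0, 2, 1]))
```

Computing the prefix product once and scoring all candidates with a single batched matmul avoids re-multiplying the whole sequence for every candidate token.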
How does one calculate the expected number of teacher forward passes, and likewise the number of gradient-descent steps on the student network, needed before the student's performance plateaus, given the two networks' parameter counts? And how does this expected number of teacher forward passes scale with versus without an added Hessian loss term?