jqmc package
Submodules
jqmc.atomic_orbital module
Atomic orbitals module.
Module containing classes and methods related to atomic orbitals.
- class jqmc.atomic_orbital.AOs_cart_data(structure_data=<factory>, nucleus_index=<factory>, num_ao=0, num_ao_prim=0, orbital_indices=<factory>, exponents=<factory>, coefficients=<factory>, angular_momentums=<factory>, polynominal_order_x=<factory>, polynominal_order_y=<factory>, polynominal_order_z=<factory>)
Bases: object
Atomic orbital definitions in Cartesian form.
Stores contracted Gaussian basis data used to evaluate atomic orbitals (AOs) at electron coordinates. The angular part of each AO is a Cartesian polynomial \(x^{n_x} y^{n_y} z^{n_z}\) with \(n_x + n_y + n_z = l\).
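As a rough illustration of the Cartesian form, the sketch below evaluates a single contracted Cartesian AO at one point as \(x^{n_x} y^{n_y} z^{n_z} \sum_k c_k e^{-\alpha_k r^2}\), with coordinates measured from the AO center. The helper eval_cart_ao is hypothetical and not part of jqmc. It also deliberately omits primitive normalization constants, which jqmc may apply on its own.

```python
import math


def eval_cart_ao(r, center, exponents, coefficients, nx, ny, nz):
    """Hypothetical helper: unnormalized contracted Cartesian Gaussian AO at point r.

    phi(r) = x^nx * y^ny * z^nz * sum_k c_k * exp(-alpha_k * |r - R|^2),
    where (x, y, z) = r - R. Normalization constants are intentionally omitted.
    """
    x, y, z = (r[i] - center[i] for i in range(3))
    r2 = x * x + y * y + z * z
    angular = x**nx * y**ny * z**nz  # Cartesian polynomial with nx + ny + nz = l
    radial = sum(c * math.exp(-a * r2) for a, c in zip(exponents, coefficients))
    return angular * radial


# Contracted S shell of H cc-pVTZ (values from the Cartesian example below).
s_exps = [33.87, 5.095, 1.159, 0.3258, 0.1027]
s_coefs = [0.006068, 0.045308, 0.202822, 0.503903, 0.383421]
center = (0.0, 0.0, -0.70)

phi_s = eval_cart_ao((0.0, 0.0, 0.0), center, s_exps, s_coefs, 0, 0, 0)
phi_px = eval_cart_ao((0.0, 0.0, 0.0), center, [1.407], [1.0], 1, 0, 0)
print(phi_s, phi_px)  # p_x vanishes on the x = 0 plane
```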
- Variables:
structure_data (Structure_data) – Molecular structure and atomic positions used to place AOs.
nucleus_index (list[int] | tuple[int]) – AO -> atom index mapping (len == num_ao).
num_ao (int) – Number of contracted AOs.
num_ao_prim (int) – Number of primitive Gaussians.
orbital_indices (list[int] | tuple[int]) – For each primitive, the parent AO index (len == num_ao_prim).
exponents (list[float] | tuple[float]) – Gaussian exponents for primitives (len == num_ao_prim).
coefficients (list[float] | tuple[float]) – Contraction coefficients per primitive (len == num_ao_prim).
angular_momentums (list[int] | tuple[int]) – Angular momentum quantum number l per AO (len == num_ao).
polynominal_order_x (list[int] | tuple[int]) – Cartesian power n_x for each AO (len == num_ao).
polynominal_order_y (list[int] | tuple[int]) – Cartesian power n_y for each AO (len == num_ao).
polynominal_order_z (list[int] | tuple[int]) – Cartesian power n_z for each AO (len == num_ao).
- Parameters:
structure_data (Structure_data)
nucleus_index (list[int] | tuple[int])
num_ao (int)
num_ao_prim (int)
orbital_indices (list[int] | tuple[int])
exponents (list[float] | tuple[float])
coefficients (list[float] | tuple[float])
angular_momentums (list[int] | tuple[int])
polynominal_order_x (list[int] | tuple[int])
polynominal_order_y (list[int] | tuple[int])
polynominal_order_z (list[int] | tuple[int])
Examples
Minimal hydrogen dimer (bohr) with all-electron cc-pVTZ (Gaussian format) in Cartesian form:
    from jqmc.structure import Structure_data
    from jqmc.atomic_orbital import AOs_cart_data

    structure = Structure_data(
        positions=[[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]],
        pbc_flag=False,
        atomic_numbers=[1, 1],
        element_symbols=["H", "H"],
        atomic_labels=["H1", "H2"],
    )

    # cc-pVTZ primitives duplicated per Cartesian component; counts:
    # per atom -> 15 AOs, 19 primitives; for two atoms num_ao=30, num_ao_prim=38
    exponents = [
        0.3258, 33.87, 5.095, 1.159, 0.3258, 0.1027, 0.1027,
        1.407, 1.407, 1.407, 0.388, 0.388, 0.388,
        1.057, 1.057, 1.057, 1.057, 1.057, 1.057,
        0.3258, 33.87, 5.095, 1.159, 0.3258, 0.1027, 0.1027,
        1.407, 1.407, 1.407, 0.388, 0.388, 0.388,
        1.057, 1.057, 1.057, 1.057, 1.057, 1.057,
    ]
    coefficients = [
        1.0, 0.006068, 0.045308, 0.202822, 0.503903, 0.383421, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 0.006068, 0.045308, 0.202822, 0.503903, 0.383421, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    ]
    orbital_indices = [
        0,                       # S shell (uncontracted)
        1, 1, 1, 1, 1,           # S shell (contracted)
        2,                       # S shell (uncontracted)
        3, 4, 5,                 # P shell (uncontracted)
        6, 7, 8,                 # P shell (uncontracted)
        9, 10, 11, 12, 13, 14,   # D shell (uncontracted)
        15,                      # S shell (uncontracted)
        16, 16, 16, 16, 16,      # S shell (contracted)
        17,                      # S shell (uncontracted)
        18, 19, 20,              # P shell (uncontracted)
        21, 22, 23,              # P shell (uncontracted)
        24, 25, 26, 27, 28, 29,  # D shell (uncontracted)
    ]
    nucleus_index = [*([0] * 15), *([1] * 15)]
    angular_momentums = [
        0, 0, 0,           # three S shells (1 orbital each)
        1, 1, 1, 1, 1, 1,  # two P shells (3 orbitals each)
        2, 2, 2, 2, 2, 2,  # one D shell (6 Cartesian components)
        0, 0, 0,
        1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2,
    ]
    polynominal_order_x = [
        0, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 0,
    ]
    polynominal_order_y = [
        0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 2, 1, 0,
        0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 2, 1, 0,
    ]
    polynominal_order_z = [
        0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 2,
        0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 2,
    ]
    num_ao = len(angular_momentums)  # 30
    num_ao_prim = len(exponents)     # 38

    aos = AOs_cart_data(
        structure_data=structure,
        nucleus_index=nucleus_index,
        num_ao=num_ao,
        num_ao_prim=num_ao_prim,
        orbital_indices=orbital_indices,
        exponents=exponents,
        coefficients=coefficients,
        angular_momentums=angular_momentums,
        polynominal_order_x=polynominal_order_x,
        polynominal_order_y=polynominal_order_y,
        polynominal_order_z=polynominal_order_z,
    )
    aos.sanity_check()
- angular_momentums: list[int] | tuple[int]
Angular momentum quantum number l per AO (len == num_ao).
- coefficients: list[float] | tuple[float]
Contraction coefficients per primitive (len == num_ao_prim).
- exponents: list[float] | tuple[float]
Gaussian exponents for primitives (len == num_ao_prim).
- nucleus_index: list[int] | tuple[int]
AO -> atom index mapping (len == num_ao).
- num_ao: int = 0
Number of contracted AOs.
- num_ao_prim: int = 0
Number of primitive Gaussians.
- orbital_indices: list[int] | tuple[int]
For each primitive, the parent AO index (len == num_ao_prim).
- polynominal_order_x: list[int] | tuple[int]
Cartesian power n_x for each AO (len == num_ao).
- polynominal_order_y: list[int] | tuple[int]
Cartesian power n_y for each AO (len == num_ao).
- polynominal_order_z: list[int] | tuple[int]
Cartesian power n_z for each AO (len == num_ao).
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- sanity_check()
Validate AO array shapes and basic types.
Ensures that array lengths match the declared counts (num_ao, num_ao_prim) and that inputs are provided as lists/tuples (ints for counts). Call this after constructing AOs_cart_data and before AO evaluation routines.
- Raises:
ValueError – When any of the following holds:
  - len(nucleus_index) != num_ao
  - len(unique(orbital_indices)) != num_ao
  - len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim
  - len(angular_momentums) != num_ao
  - len(polynominal_order_{x,y,z}) != num_ao
  - any attribute has an unexpected Python type (e.g., non-list/tuple arrays, non-int counts)
- Return type:
None
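The length conditions listed above can be mirrored in a few lines of plain Python, for example to pre-validate hand-written basis lists before constructing the dataclass. The function check_cart_ao_lengths is hypothetical and not part of jqmc; it restates only the documented length rules, not the type checks.

```python
def check_cart_ao_lengths(num_ao, num_ao_prim, nucleus_index, orbital_indices,
                          exponents, coefficients, angular_momentums,
                          orders_x, orders_y, orders_z):
    """Hypothetical pre-check mirroring the documented ValueError length conditions."""
    if len(nucleus_index) != num_ao:
        raise ValueError("len(nucleus_index) != num_ao")
    if len(set(orbital_indices)) != num_ao:
        raise ValueError("len(unique(orbital_indices)) != num_ao")
    if len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim:
        raise ValueError("exponents/coefficients length != num_ao_prim")
    if len(angular_momentums) != num_ao:
        raise ValueError("len(angular_momentums) != num_ao")
    for name, orders in (("x", orders_x), ("y", orders_y), ("z", orders_z)):
        if len(orders) != num_ao:
            raise ValueError(f"len(polynominal_order_{name}) != num_ao")


# One s AO contracted from two primitives plus one p_z AO from a single primitive.
check_cart_ao_lengths(
    num_ao=2, num_ao_prim=3,
    nucleus_index=[0, 0], orbital_indices=[0, 0, 1],
    exponents=[1.0, 0.5, 0.8], coefficients=[0.6, 0.4, 1.0],
    angular_momentums=[0, 1],
    orders_x=[0, 0], orders_y=[0, 0], orders_z=[0, 1],
)  # passes silently
```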
- structure_data: Structure_data
Molecular structure and atomic positions used to place AOs.
- class jqmc.atomic_orbital.AOs_sphe_data(structure_data=<factory>, nucleus_index=<factory>, num_ao=0, num_ao_prim=0, orbital_indices=<factory>, exponents=<factory>, coefficients=<factory>, angular_momentums=<factory>, magnetic_quantum_numbers=<factory>)
Bases: object
Atomic orbital definitions in real spherical-harmonic form.
Stores contracted Gaussian basis data for atomic orbitals (AOs) whose angular part is a real spherical harmonic \(Y_{l}^{m}\) with \(m \in \{-l,\dots,+l\}\).
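The per-atom counts in the examples follow from shell sizes: a shell of angular momentum \(l\) contributes \(2l+1\) real spherical components but \((l+1)(l+2)/2\) Cartesian ones. The sketch below, using a hypothetical helper expand_shells, expands the cc-pVTZ hydrogen shell list (three S, two P, one D) into per-AO angular momenta and magnetic quantum numbers, following the \(m = -l, \dots, +l\) ordering used in the spherical example below.

```python
def expand_shells(shell_ls):
    """Hypothetical helper: expand shell angular momenta into per-AO (l, m) lists."""
    angular_momentums, magnetic_quantum_numbers = [], []
    for l in shell_ls:
        for m in range(-l, l + 1):  # real spherical harmonics: m = -l..+l
            angular_momentums.append(l)
            magnetic_quantum_numbers.append(m)
    return angular_momentums, magnetic_quantum_numbers


def n_cartesian(l):
    """Number of Cartesian monomials x^nx y^ny z^nz with nx + ny + nz = l."""
    return (l + 1) * (l + 2) // 2


h_cc_pvtz_shells = [0, 0, 0, 1, 1, 2]  # S, S, S, P, P, D per hydrogen atom
ls, ms = expand_shells(h_cc_pvtz_shells)
print(len(ls))                                        # 14 spherical AOs per atom
print(sum(n_cartesian(l) for l in h_cc_pvtz_shells))  # 15 Cartesian AOs per atom
```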
- Variables:
structure_data (Structure_data) – Molecular structure and atomic positions used to place AOs.
nucleus_index (list[int] | tuple[int]) – AO -> atom index mapping (len == num_ao).
num_ao (int) – Number of contracted AOs.
num_ao_prim (int) – Number of primitive Gaussians.
orbital_indices (list[int] | tuple[int]) – For each primitive, the parent AO index (len == num_ao_prim).
exponents (list[float] | tuple[float]) – Gaussian exponents for primitives (len == num_ao_prim).
coefficients (list[float] | tuple[float]) – Contraction coefficients per primitive (len == num_ao_prim).
angular_momentums (list[int] | tuple[int]) – Angular momentum quantum number l per AO (len == num_ao).
magnetic_quantum_numbers (list[int] | tuple[int]) – Magnetic quantum number m per AO (len == num_ao), satisfying -l <= m <= l.
- Parameters:
structure_data (Structure_data)
nucleus_index (list[int] | tuple[int])
num_ao (int)
num_ao_prim (int)
orbital_indices (list[int] | tuple[int])
exponents (list[float] | tuple[float])
coefficients (list[float] | tuple[float])
angular_momentums (list[int] | tuple[int])
magnetic_quantum_numbers (list[int] | tuple[int])
Examples
Hydrogen dimer (bohr) with all-electron cc-pVTZ (Gaussian format), real spherical harmonics:
    from jqmc.structure import Structure_data
    from jqmc.atomic_orbital import AOs_sphe_data

    structure = Structure_data(
        positions=[[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]],
        pbc_flag=False,
        atomic_numbers=[1, 1],
        element_symbols=["H", "H"],
        atomic_labels=["H1", "H2"],
    )

    # Per atom: 14 AOs (3 S, 2×P shells -> 6, 1×D shell -> 5); 18 primitives.
    # Two atoms -> num_ao=28, num_ao_prim=36.
    exponents = [
        # atom 1
        0.3258, 33.87, 5.095, 1.159, 0.3258, 0.1027, 0.1027,
        1.407, 1.407, 1.407, 0.388, 0.388, 0.388,
        1.057, 1.057, 1.057, 1.057, 1.057,
        # atom 2 (same order)
        0.3258, 33.87, 5.095, 1.159, 0.3258, 0.1027, 0.1027,
        1.407, 1.407, 1.407, 0.388, 0.388, 0.388,
        1.057, 1.057, 1.057, 1.057, 1.057,
    ]
    coefficients = [
        # atom 1
        1.0, 0.006068, 0.045308, 0.202822, 0.503903, 0.383421, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0,
        # atom 2 (same order)
        1.0, 0.006068, 0.045308, 0.202822, 0.503903, 0.383421, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0,
    ]
    orbital_indices = [
        # atom 1: AOs 0-13
        0, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
        # atom 2: AOs 14-27
        14, 15, 15, 15, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
    ]
    nucleus_index = [*([0] * 14), *([1] * 14)]
    angular_momentums = [
        # atom 1
        0, 0, 0,           # S shells
        1, 1, 1, 1, 1, 1,  # two P shells (3 each)
        2, 2, 2, 2, 2,     # one D shell (5)
        # atom 2
        0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,
    ]
    magnetic_quantum_numbers = [
        # atom 1 (S: m=0; P: -1, 0, 1; D: -2..2)
        0, 0, 0, -1, 0, 1, -1, 0, 1, -2, -1, 0, 1, 2,
        # atom 2
        0, 0, 0, -1, 0, 1, -1, 0, 1, -2, -1, 0, 1, 2,
    ]
    aos = AOs_sphe_data(
        structure_data=structure,
        nucleus_index=nucleus_index,
        num_ao=len(angular_momentums),
        num_ao_prim=len(exponents),
        orbital_indices=orbital_indices,
        exponents=exponents,
        coefficients=coefficients,
        angular_momentums=angular_momentums,
        magnetic_quantum_numbers=magnetic_quantum_numbers,
    )
    aos.sanity_check()
- angular_momentums: list[int] | tuple[int]
Angular momentum quantum number l per AO (len == num_ao).
- coefficients: list[float] | tuple[float]
Contraction coefficients per primitive (len == num_ao_prim).
- exponents: list[float] | tuple[float]
Gaussian exponents for primitives (len == num_ao_prim).
- magnetic_quantum_numbers: list[int] | tuple[int]
Magnetic quantum number m per AO (len == num_ao; -l <= m <= l).
- nucleus_index: list[int] | tuple[int]
AO -> atom index mapping (len == num_ao).
- num_ao: int = 0
Number of contracted AOs.
- num_ao_prim: int = 0
Number of primitive Gaussians.
- orbital_indices: list[int] | tuple[int]
For each primitive, the parent AO index (len == num_ao_prim).
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- sanity_check()
Validate AO array shapes and basic types.
Ensures that array lengths match the declared counts (num_ao, num_ao_prim) and that inputs are provided as lists/tuples (ints for counts). Call this after constructing AOs_sphe_data and before AO evaluation routines.
- Raises:
ValueError – When any of the following holds:
  - len(nucleus_index) != num_ao
  - len(unique(orbital_indices)) != num_ao
  - len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim
  - len(angular_momentums) != num_ao
  - len(magnetic_quantum_numbers) != num_ao
  - any attribute has an unexpected Python type (e.g., non-list/tuple arrays, non-int counts)
- Return type:
None
- structure_data: Structure_data
Molecular structure and atomic positions used to place AOs.
- jqmc.atomic_orbital.compute_AOs(aos_data, r_carts)
Evaluate contracted atomic orbitals (AOs) at electron coordinates.
Dispatches to the Cartesian or real-spherical backend and returns float64 JAX arrays (ensure jax_enable_x64=True). Call aos_data.sanity_check() before use.
- Parameters:
aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing centers, primitive parameters, angular data, and contraction mapping.
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3), in bohr. Cast to float64 internally via jnp.asarray.
- Returns:
AO values, shape (num_ao, N_e).
- Return type:
jax.Array
- Raises:
NotImplementedError – If aos_data is neither Cartesian nor spherical.
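The (num_ao, N_e) output layout and the role of orbital_indices can be illustrated with a NumPy sketch restricted to s-type AOs: every primitive is evaluated for all electrons at once, then scatter-added into its parent AO row. This is not the jqmc kernel, which runs in JAX and handles all angular momenta; the function name eval_s_aos and the omission of normalization are assumptions made for illustration.

```python
import numpy as np


def eval_s_aos(r_carts, centers, num_ao, orbital_indices, exponents, coefficients):
    """Hypothetical sketch: contracted s-type AO values with shape (num_ao, N_e).

    centers[k] is the position of the nucleus carrying primitive k's parent AO.
    """
    r_carts = np.asarray(r_carts, dtype=np.float64)   # (N_e, 3)
    centers = np.asarray(centers, dtype=np.float64)   # (num_ao_prim, 3)
    diff = r_carts[None, :, :] - centers[:, None, :]  # (num_ao_prim, N_e, 3)
    r2 = np.sum(diff**2, axis=-1)                     # (num_ao_prim, N_e)
    alphas = np.asarray(exponents)[:, None]
    prim = np.asarray(coefficients)[:, None] * np.exp(-alphas * r2)
    aos = np.zeros((num_ao, r_carts.shape[0]))
    np.add.at(aos, np.asarray(orbital_indices), prim)  # contract primitives -> AOs
    return aos


# Two s AOs on one H atom: AO 0 contracts two primitives, AO 1 is uncontracted.
r_carts = np.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 1.0, 0.0]])
values = eval_s_aos(
    r_carts,
    centers=[[0.0, 0.0, 0.0]] * 3,
    num_ao=2,
    orbital_indices=[0, 0, 1],
    exponents=[1.159, 0.3258, 0.1027],
    coefficients=[0.202822, 0.503903, 1.0],
)
print(values.shape)  # (2, 3)
```

np.add.at is used instead of fancy-index assignment because it accumulates repeated indices, which is exactly what a contracted AO with several primitives needs.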
- jqmc.atomic_orbital.compute_AOs_grad(aos_data, r_carts)
Return analytic Cartesian gradients of contracted atomic orbitals.
Public gradient API used for drift vectors and kinetic terms. Dispatches to the Cartesian or real-spherical backend and returns float64 JAX arrays (ensure jax_enable_x64=True).
- Parameters:
aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing primitive parameters, angular information, contraction mapping, and centers (run sanity_check() beforehand).
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3), in bohr. Cast to float64 internally via jnp.asarray.
- Returns:
Gradients with respect to x, y, and z, each of shape (num_ao, N_e), in the order (gx, gy, gz).
- Return type:
tuple[jax.Array, jax.Array, jax.Array]
- Raises:
NotImplementedError – If aos_data is neither Cartesian nor spherical.
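For a single unnormalized Cartesian primitive \(g = x^{n_x} y^{n_y} z^{n_z} e^{-\alpha r^2}\), the analytic derivative is \(\partial_x g = (n_x x^{n_x-1} - 2\alpha x^{n_x+1})\, y^{n_y} z^{n_z} e^{-\alpha r^2}\), and likewise for y and z. The sketch below codes that formula (hypothetical helper names, center at the origin, no normalization) and checks it against a central finite difference, a convenient way to validate any analytic gradient backend.

```python
import math


def prim(r, a, n):
    """Unnormalized Cartesian primitive x^nx y^ny z^nz exp(-a r^2), centered at the origin."""
    x, y, z = r
    return x**n[0] * y**n[1] * z**n[2] * math.exp(-a * (x * x + y * y + z * z))


def prim_grad(r, a, n):
    """Analytic gradient using d/dq [q^k e^{-a q^2}] = (k q^(k-1) - 2a q^(k+1)) e^{-a q^2}."""
    e = math.exp(-a * sum(q * q for q in r))
    grad = []
    for d in range(3):
        factor = 1.0
        for j in range(3):
            q, k = r[j], n[j]
            if j == d:
                first = k * q ** (k - 1) if k > 0 else 0.0  # avoid q**-1 when k == 0
                factor *= first - 2.0 * a * q ** (k + 1)
            else:
                factor *= q**k
        grad.append(factor * e)
    return grad


def fd_grad(r, a, n, h=1e-5):
    """Central finite-difference gradient for cross-checking."""
    out = []
    for d in range(3):
        rp, rm = list(r), list(r)
        rp[d] += h
        rm[d] -= h
        out.append((prim(rp, a, n) - prim(rm, a, n)) / (2 * h))
    return out


r, a, n = (0.3, -0.4, 0.7), 1.057, (1, 0, 1)  # a d_xz-type Cartesian primitive
print(prim_grad(r, a, n))
print(fd_grad(r, a, n))  # should agree to about 1e-9
```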
- jqmc.atomic_orbital.compute_AOs_laplacian(aos_data, r_carts)
Return analytic Laplacians of contracted atomic orbitals.
Dispatches to the Cartesian or real-spherical implementation and returns float64 JAX arrays (ensure jax_enable_x64=True).
- Parameters:
aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing centers, primitive exponents/coefficients, angular data, and contraction mapping (run sanity_check() beforehand).
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3), in bohr. Cast to float64 internally via jnp.asarray.
- Returns:
Laplacians of all contracted AOs, shape (num_ao, N_e).
- Return type:
jax.Array
- Raises:
NotImplementedError – If aos_data is neither Cartesian nor spherical.
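The same one-dimensional factorization gives the Laplacian: \(\partial_q^2 [q^k e^{-\alpha q^2}] = [k(k-1) q^{k-2} - 2\alpha(2k+1) q^k + 4\alpha^2 q^{k+2}]\, e^{-\alpha q^2}\), summed over x, y, z with the other two factors held fixed. For an s primitive this reduces to the familiar \((4\alpha^2 r^2 - 6\alpha) e^{-\alpha r^2}\). A minimal sketch under the same assumptions as a plain gradient check (hypothetical names, origin-centered, unnormalized), validated by a second-order finite difference:

```python
import math


def prim(r, a, n):
    """Unnormalized Cartesian primitive x^nx y^ny z^nz exp(-a r^2), centered at the origin."""
    x, y, z = r
    return x**n[0] * y**n[1] * z**n[2] * math.exp(-a * (x * x + y * y + z * z))


def prim_laplacian(r, a, n):
    """Analytic Laplacian of the primitive above."""
    e = math.exp(-a * sum(q * q for q in r))
    total = 0.0
    for d in range(3):
        term = 1.0
        for j in range(3):
            q, k = r[j], n[j]
            if j == d:
                second = k * (k - 1) * q ** (k - 2) if k >= 2 else 0.0
                term *= second - 2.0 * a * (2 * k + 1) * q**k + 4.0 * a * a * q ** (k + 2)
            else:
                term *= q**k
        total += term
    return total * e


def fd_laplacian(r, a, n, h=1e-4):
    """Central second-difference Laplacian for cross-checking."""
    f0 = prim(r, a, n)
    total = 0.0
    for d in range(3):
        rp, rm = list(r), list(r)
        rp[d] += h
        rm[d] -= h
        total += (prim(rp, a, n) - 2.0 * f0 + prim(rm, a, n)) / (h * h)
    return total


r, a = (0.3, -0.4, 0.7), 1.407
for n in [(0, 0, 0), (1, 0, 0), (2, 0, 0), (1, 1, 0)]:
    print(n, prim_laplacian(r, a, n), fd_laplacian(r, a, n))
```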
jqmc.coulomb_potential module
Effective core potential module.
Module containing classes and methods related to effective core potentials (ECPs) and bare Coulomb potentials.
- class jqmc.coulomb_potential.Coulomb_potential_data(structure_data=<factory>, ecp_flag=False, z_cores=<factory>, max_ang_mom_plus_1=<factory>, num_ecps=0, ang_moms=<factory>, nucleus_index=<factory>, exponents=<factory>, coefficients=<factory>, powers=<factory>)
Bases: object
Container for bare Coulomb and effective core potential (ECP) parameters.
- Parameters:
structure_data (Structure_data) – Underlying nuclear geometry and metadata.
ecp_flag (bool) – Whether ECPs are present. When True, all ECP arrays must be populated.
z_cores (list[float] | tuple[float]) – Core electrons removed per atom; length natom.
max_ang_mom_plus_1 (list[int] | tuple[int]) – l_max + 1 for each atom; length natom.
num_ecps (int) – Total number of ECP projector terms across all atoms and angular momenta.
ang_moms (list[int] | tuple[int]) – Angular momentum l per ECP term; length num_ecps.
nucleus_index (list[int] | tuple[int]) – Atom index per ECP term; length num_ecps.
exponents (list[float] | tuple[float]) – Gaussian exponents per ECP term; length num_ecps.
coefficients (list[float] | tuple[float]) – Prefactors per ECP term; length num_ecps.
powers (list[int] | tuple[int]) – Polynomial powers per ECP term; length num_ecps.
Notes
When ecp_flag is False, all ECP-related sequences must be empty and num_ecps should be 0.
Arrays are stored as Python lists/tuples for pytrees; conversion to jax.Array happens in the compute kernels.
- ang_moms: list[int] | tuple[int]
Angular momentum l per ECP term (len = num_ecps).
- coefficients: list[float] | tuple[float]
Prefactors per ECP term (len = num_ecps).
- ecp_flag: bool = False
Whether ECP parameters are active.
- exponents: list[float] | tuple[float]
Gaussian exponents per ECP term (len = num_ecps).
- max_ang_mom_plus_1: list[int] | tuple[int]
l_max + 1 per atom (len = natom).
- nucleus_index: list[int] | tuple[int]
Atom index per ECP term (len = num_ecps).
- num_ecps: int = 0
Total ECP projector terms across all atoms.
- powers: list[int] | tuple[int]
Polynomial powers per ECP term (len = num_ecps).
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- sanity_check()
Check the attributes of the class.
This function checks the consistency of dimensions among the arguments.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
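The dimension rules in the parameter list above reduce to a handful of length comparisons. The sketch below restates them in plain Python; check_ecp_lengths is a hypothetical helper, not jqmc's sanity_check, and for the ecp_flag=False case it only checks num_ecps and the per-term arrays.

```python
def check_ecp_lengths(natom, ecp_flag, z_cores, max_ang_mom_plus_1, num_ecps,
                      ang_moms, nucleus_index, exponents, coefficients, powers):
    """Hypothetical restatement of the documented Coulomb_potential_data shape rules."""
    per_term = {"ang_moms": ang_moms, "nucleus_index": nucleus_index,
                "exponents": exponents, "coefficients": coefficients, "powers": powers}
    if not ecp_flag:
        # Without ECPs, num_ecps should be 0 and the per-term arrays empty.
        if num_ecps != 0 or any(len(v) for v in per_term.values()):
            raise ValueError("ecp_flag is False but ECP terms are present")
        return
    for name, seq in (("z_cores", z_cores), ("max_ang_mom_plus_1", max_ang_mom_plus_1)):
        if len(seq) != natom:
            raise ValueError(f"len({name}) != natom")
    for name, seq in per_term.items():
        if len(seq) != num_ecps:
            raise ValueError(f"len({name}) != num_ecps")


# Made-up values that only exercise the length checks: one atom, two projector terms.
check_ecp_lengths(
    natom=1, ecp_flag=True, z_cores=[2.0], max_ang_mom_plus_1=[2], num_ecps=2,
    ang_moms=[0, 1], nucleus_index=[0, 0], exponents=[1.0, 2.0],
    coefficients=[-1.0, 0.5], powers=[2, 2],
)
```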
- structure_data: Structure_data
Nuclear geometry and atom metadata.
- z_cores: list[float] | tuple[float]
Core electrons removed per atom (len = natom).
- jqmc.coulomb_potential.compute_bare_coulomb_potential(coulomb_potential_data, r_up_carts, r_dn_carts)
Compute the bare Coulomb interaction (ion–ion, electron–ion, and electron–electron).
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs are present).
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
- Returns:
Total bare Coulomb energy.
- Return type:
float
- jqmc.coulomb_potential.compute_bare_coulomb_potential_el_el(r_up_carts, r_dn_carts)
Electron–electron Coulomb interaction energy.
- Parameters:
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
- Returns:
Electron–electron Coulomb energy.
- Return type:
float
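Both spin channels interact through the same \(1/r_{ij}\) kernel, so the electron–electron term is a sum over all distinct pairs of the concatenated up- and down-spin coordinates. A NumPy sketch follows; the function name is hypothetical, and jqmc's own kernel is written in JAX and may organize the sum differently.

```python
import numpy as np


def el_el_energy(r_up_carts, r_dn_carts):
    """Sketch: sum over distinct electron pairs of 1/|r_i - r_j| (hartree, bohr)."""
    r_up = np.asarray(r_up_carts, dtype=np.float64).reshape(-1, 3)
    r_dn = np.asarray(r_dn_carts, dtype=np.float64).reshape(-1, 3)
    r = np.concatenate([r_up, r_dn])                  # spin does not enter 1/r_ij
    dist = np.linalg.norm(r[:, None, :] - r[None, :, :], axis=-1)
    i, j = np.triu_indices(len(r), k=1)               # each unordered pair once
    return float(np.sum(1.0 / dist[i, j]))


print(el_el_energy([[0.0, 0.0, 0.0]], [[0.0, 0.0, 2.0]]))  # 0.5
```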
- jqmc.coulomb_potential.compute_bare_coulomb_potential_el_ion(coulomb_potential_data, r_up_carts, r_dn_carts)
Total electron–ion Coulomb interaction energy.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs are present).
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
- Returns:
Electron–ion Coulomb energy.
- Return type:
float
- jqmc.coulomb_potential.compute_bare_coulomb_potential_el_ion_element_wise(coulomb_potential_data, r_up_carts, r_dn_carts)
Element-wise electron–ion Coulomb interactions.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs are present).
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
- Returns:
Element-wise ion–electron interactions for up spins and down spins (shapes (N_up,) and (N_dn,)).
- Return type:
tuple[jax.Array, jax.Array]
- jqmc.coulomb_potential.compute_bare_coulomb_potential_ion_ion(coulomb_potential_data)
Ion–ion Coulomb interaction energy.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs are present).
- Returns:
Ion–ion Coulomb energy.
- Return type:
float
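Assuming, as the "effective if ECPs are present" wording suggests, that each nucleus carries the charge \(Z_I - z_{\mathrm{core},I}\), the ion–ion term is the pairwise sum \(\sum_{I<J} Z_I Z_J / |R_I - R_J|\). The sketch below (hypothetical function name) reproduces the nuclear repulsion of H\(_2\) at 1.4 bohr, the geometry used in the atomic-orbital examples.

```python
import itertools
import math


def ion_ion_energy(positions, atomic_numbers, z_cores=None):
    """Sketch: nuclear repulsion with effective charges Z_I - z_core_I (assumption)."""
    if z_cores is None:
        z_cores = [0.0] * len(atomic_numbers)
    charges = [z - c for z, c in zip(atomic_numbers, z_cores)]
    energy = 0.0
    for i, j in itertools.combinations(range(len(positions)), 2):
        energy += charges[i] * charges[j] / math.dist(positions[i], positions[j])
    return energy


print(ion_ion_energy([[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]], [1, 1]))  # about 0.7143 hartree
```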
- jqmc.coulomb_potential.compute_coulomb_potential(coulomb_potential_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6, wavefunction_data=None)
Compute the total Coulomb energy, including bare and ECP terms.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure, charges, and ECP parameters.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)) for the non-local ECP.
NN (int) – Number of nearest nuclei to include for each electron in the non-local term.
Nv (int) – Number of quadrature points on the sphere.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ECP ratios; required when ecp_flag is True.
- Returns:
Sum of the bare Coulomb (ion–ion, electron–ion, electron–electron) and ECP (local + non-local) energies.
- Return type:
float
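RT and Nv refer to the angular quadrature used for the non-local ECP projection. Assuming the common six-point rule for Nv=6 (the octahedron vertices \(\pm\hat{x}, \pm\hat{y}, \pm\hat{z}\) with equal weights 1/6, rotated by RT, typically a random rotation), the sketch below builds the rotated grid and checks that it reproduces the exact sphere averages of the Legendre polynomials \(P_0, P_1, P_2\), which is what a projector of the form \(\sum_l (2l+1) P_l(\cos\theta)\) relies on. Function names are hypothetical, and jqmc's actual grid and rotation convention are not confirmed here.

```python
import numpy as np


def octahedral_grid(RT):
    """Six-point octahedral quadrature (assumed for Nv=6), rotated by RT."""
    base = np.array([[1, 0, 0], [-1, 0, 0], [0, 1, 0],
                     [0, -1, 0], [0, 0, 1], [0, 0, -1]], dtype=np.float64)
    weights = np.full(6, 1.0 / 6.0)  # weights sum to 1: a sphere average
    return weights, base @ RT.T      # each point p becomes RT @ p


def rotation_z(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotation_x(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


RT = rotation_z(0.7) @ rotation_x(1.1)  # stand-in for a random rotation
weights, points = octahedral_grid(RT)

u = np.array([0.3, -0.5, 0.81])
u /= np.linalg.norm(u)                  # electron direction as seen from the nucleus
cos_t = points @ u
legendre = {0: np.ones_like(cos_t), 1: cos_t, 2: 0.5 * (3 * cos_t**2 - 1)}
for l, p_l in legendre.items():
    print(l, weights @ p_l)             # approximately 1, 0, 0 (exact sphere averages)
```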
- jqmc.coulomb_potential.compute_discretized_bare_coulomb_potential_el_ion_element_wise(coulomb_potential_data, r_up_carts, r_dn_carts, alat)
Element-wise electron–ion Coulomb interactions with distance floor alat.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs are present).
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
alat (float) – Minimum allowed distance, used to avoid the divergence at the nuclei.
- Returns:
Element-wise ion–electron interactions for up spins and down spins (shapes (N_up,) and (N_dn,)).
- Return type:
tuple[jax.Array, jax.Array]
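The element-wise variants return one value per electron rather than a scalar total, and the discretized variant additionally floors each electron–nucleus distance at alat. The NumPy sketch below shows both behaviours. It assumes effective charges \(Z_I - z_{\mathrm{core},I}\) are passed in directly and that the floor acts as max(r, alat); the function name is hypothetical.

```python
import numpy as np


def el_ion_element_wise(charges, positions, r_up_carts, r_dn_carts, alat=None):
    """Sketch: per-electron energy -sum_I Q_I / max(|r - R_I|, alat); alat=None disables the floor."""
    Q = np.asarray(charges, dtype=np.float64)    # (N_ion,) effective charges
    R = np.asarray(positions, dtype=np.float64)  # (N_ion, 3)

    def per_electron(r_carts):
        r = np.asarray(r_carts, dtype=np.float64).reshape(-1, 3)        # (N, 3)
        dist = np.linalg.norm(r[:, None, :] - R[None, :, :], axis=-1)  # (N, N_ion)
        if alat is not None:
            dist = np.maximum(dist, alat)        # distance floor
        return -np.sum(Q[None, :] / dist, axis=1)  # (N,)

    return per_electron(r_up_carts), per_electron(r_dn_carts)


positions = [[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]
r_up = [[0.0, 0.0, 0.0]]
r_dn = [[0.0, 0.0, 0.75]]  # 0.05 bohr from the second nucleus
up, dn = el_ion_element_wise([1.0, 1.0], positions, r_up, r_dn)
up_f, dn_f = el_ion_element_wise([1.0, 1.0], positions, r_up, r_dn, alat=0.1)
print(up, dn, dn_f)         # the floor caps the near-nucleus 1/r term at 1/alat
print(up.sum() + dn.sum())  # total electron-ion energy
```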
- jqmc.coulomb_potential.compute_ecp_coulomb_potential(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6)
Compute the total ECP energy (local + non-local) for a configuration.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).
NN (int) – Number of nearest nuclei to include for each electron in the non-local term.
Nv (int) – Number of quadrature points on the sphere.
- Returns:
Sum of local and non-local ECP contributions for the given electron configuration.
- Return type:
float
- jqmc.coulomb_potential.compute_ecp_local_parts_all_pairs(coulomb_potential_data, r_up_carts, r_dn_carts)
Compute the local ECP contribution over all nucleus–electron pairs.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
- Returns:
Total local ECP energy for the provided electron coordinates.
- Return type:
float
- jqmc.coulomb_potential.compute_ecp_non_local_part_all_pairs_jax_weights_grid_points(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, weights, grid_points, flag_determinant_only=0)
Vectorized non-local ECP projection over all pairs, using caller-supplied quadrature weights and grid points.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
weights (list[float]) – Quadrature weights for angular integration.
grid_points (npt.NDArray[np.float64]) – Quadrature grid points with shape (Nv, 3).
flag_determinant_only (int) – If 1, skip the Jastrow factor in the wavefunction ratio; if 0, include it.
- Returns:
Mesh-displaced up-spin coordinates per configuration.
Mesh-displaced down-spin coordinates per configuration.
Non-local contributions for up-spin mesh points.
Non-local contributions for down-spin mesh points.
Scalar sum of all non-local contributions.
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array, float]
- jqmc.coulomb_potential.compute_ecp_non_local_parts_all_pairs(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, Nv=6, flag_determinant_only=False)
Compute the non-local ECP contribution considering all nucleus–electron pairs.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).
Nv (int) – Number of quadrature points on the sphere.
flag_determinant_only (bool) – If True, ignore the Jastrow factor in the wavefunction ratio.
- Returns:
Mesh-displaced r_up_carts per configuration.
Mesh-displaced r_dn_carts per configuration.
Non-local ECP contributions per configuration (flattened).
Scalar sum of all non-local contributions.
- Return type:
tuple[list[jax.Array], list[jax.Array], jax.Array, float]
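For orientation, the semi-local form usually behind such non-local routines projects each electron–nucleus pair onto angular momentum channels: \(V_{NL} = \sum_l v_l(r)\,(2l+1) \sum_k w_k P_l(\cos\theta_k)\,\Psi(\mathbf{r}_k)/\Psi(\mathbf{r})\), where the \(\mathbf{r}_k\) are mesh-displaced electron positions on the sphere of radius \(r\) around the nucleus. The NumPy sketch below evaluates that sum for one pair, with made-up radial channels v_l and a toy one-electron wavefunction. It assumes the unrotated six-point octahedral grid and does not reproduce how jqmc builds v_l from exponents, coefficients, and powers, nor its wavefunction-ratio machinery.

```python
import numpy as np


def legendre(l, x):
    """Legendre polynomial P_l(x) for l = 0, 1, 2 (enough for this sketch)."""
    return [np.ones_like(x), x, 0.5 * (3.0 * x**2 - 1.0)][l]


def nonlocal_pair_energy(psi, r_el, R_nuc, v_l, points, weights):
    """Sketch of one electron-nucleus non-local ECP term (hypothetical helper).

    psi: callable wavefunction of this electron (all other electrons held fixed).
    v_l: list of radial functions, one per angular momentum channel l = 0, 1, ...
    """
    rel = r_el - R_nuc
    r = np.linalg.norm(rel)
    cos_t = points @ (rel / r)            # angle between r_el and each mesh direction
    mesh = R_nuc[None, :] + r * points    # mesh-displaced electron positions
    ratios = np.array([psi(p) for p in mesh]) / psi(r_el)
    return sum(v(r) * (2 * l + 1) * np.sum(weights * legendre(l, cos_t) * ratios)
               for l, v in enumerate(v_l))


def psi_toy(p):
    """Made-up one-electron wavefunction: a Slater function centered off the nucleus."""
    return np.exp(-np.linalg.norm(p - np.array([0.0, 0.0, 1.0])))


points = np.vstack([np.eye(3), -np.eye(3)])  # octahedral grid (assumed for Nv=6)
weights = np.full(6, 1.0 / 6.0)
v_l = [lambda r: -2.0 * np.exp(-r**2), lambda r: 0.5 * np.exp(-r**2)]  # made-up channels

R_nuc = np.zeros(3)
r_el = np.array([0.2, -0.3, 0.5])
flat = nonlocal_pair_energy(lambda p: 1.0, r_el, R_nuc, v_l, points, weights)
toy = nonlocal_pair_energy(psi_toy, r_el, R_nuc, v_l, points, weights)
print(flat, toy)
```

With a constant wavefunction every ratio is 1 and only the l = 0 channel survives the angular sum, which makes a handy sanity check for any implementation.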
- jqmc.coulomb_potential.compute_ecp_non_local_parts_nearest_neighbors(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6, flag_determinant_only=False)
Compute the non-local ECP contribution with a nearest-neighbor cutoff.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).
NN (int) – Number of nearest nuclei to include for each electron.
Nv (int) – Number of quadrature points on the sphere.
flag_determinant_only (bool) – If True, ignore the Jastrow factor in the wavefunction ratio.
- Returns:
Mesh-displaced r_up_carts per configuration.
Mesh-displaced r_dn_carts per configuration.
Non-local ECP contributions per configuration (flattened).
Scalar sum of all non-local contributions.
- Return type:
tuple[list[jax.Array], list[jax.Array], jax.Array, float]
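With the NN cutoff, only the NN nuclei closest to each electron enter its non-local sum, instead of every nucleus as in the all-pairs variant. Selecting those nuclei amounts to sorting each row of the electron–nucleus distance matrix; a NumPy sketch with a hypothetical helper name:

```python
import numpy as np


def nearest_nuclei(r_carts, nuclei, NN=1):
    """Sketch: indices of the NN closest nuclei for each electron, shape (N_e, NN)."""
    r = np.asarray(r_carts, dtype=np.float64)
    R = np.asarray(nuclei, dtype=np.float64)
    dist = np.linalg.norm(r[:, None, :] - R[None, :, :], axis=-1)  # (N_e, N_nuc)
    return np.argsort(dist, axis=1)[:, :NN]                        # closest first


nuclei = [[0.0, 0.0, -0.70], [0.0, 0.0, 0.70], [3.0, 0.0, 0.0]]
electrons = [[0.0, 0.1, -0.5], [2.5, 0.0, 0.2]]
print(nearest_nuclei(electrons, nuclei, NN=1).tolist())  # [[0], [2]]
print(nearest_nuclei(electrons, nuclei, NN=2).tolist())  # [[0, 1], [2, 1]]
```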
- jqmc.coulomb_potential.compute_ecp_non_local_parts_nearest_neighbors_fast_update(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, A_old_inv, NN=1, Nv=6, flag_determinant_only=False)
Fast-update variant of the non-local ECP contribution (nearest neighbors).
This variant reuses the inverse geminal matrix to compute determinant ratios, together with Jastrow ratios, and so avoids a full recomputation of the wavefunction for each mesh point.
- Parameters:
coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.
wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.
r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.
r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.
RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).
A_old_inv (jax.Array) – Inverse geminal matrix evaluated at (r_up_carts, r_dn_carts).
NN (int) – Number of nearest nuclei to include for each electron.
Nv (int) – Number of quadrature points on the sphere.
flag_determinant_only (bool) – If True, ignore the Jastrow factor in the wavefunction ratio.
- Returns:
Mesh-displaced r_up_carts per configuration.
Mesh-displaced r_dn_carts per configuration.
Non-local ECP contributions per configuration (flattened).
Scalar sum of all non-local contributions.
- Return type:
tuple[list[jax.Array], list[jax.Array], jax.Array, float]
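The fast-update idea rests on the matrix determinant lemma: if only row i of a square matrix A is replaced by a new row u, then \(\det A' / \det A = u \cdot (A^{-1})_{:,i}\), and replacing column j by v gives \((A^{-1})_{j,:} \cdot v\). Each mesh point then costs O(N) once the inverse is known, instead of a fresh O(N^3) determinant. The sketch checks both identities on a random matrix standing in for the geminal; how jqmc builds the new geminal row or column for a displaced electron, and the Jastrow ratio it multiplies in, are not reproduced.

```python
import numpy as np


def row_update_ratio(A_inv, i, new_row):
    """det(A with row i replaced by new_row) / det(A), via the matrix determinant lemma."""
    return new_row @ A_inv[:, i]


def col_update_ratio(A_inv, j, new_col):
    """Same for replacing column j (e.g. a moved spin-down electron)."""
    return A_inv[j, :] @ new_col


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5)) + 5.0 * np.eye(5)  # well-conditioned stand-in for a geminal matrix
A_inv = np.linalg.inv(A)
u = rng.normal(size=5)

A_row = A.copy()
A_row[2, :] = u
direct = np.linalg.det(A_row) / np.linalg.det(A)
print(direct, row_update_ratio(A_inv, 2, u))  # agree to rounding
```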
jqmc.determinant module
Determinant module.
- class jqmc.determinant.Geminal_data(num_electron_up=0, num_electron_dn=0, orb_data_up_spin=<factory>, orb_data_dn_spin=<factory>, lambda_matrix=<factory>)
Bases: object
Geminal (AGP) parameters and orbital references.
- Parameters:
num_electron_up (int) – Number of spin-up electrons.
num_electron_dn (int) – Number of spin-down electrons.
orb_data_up_spin (AOs_data | MOs_data) – Basis/orbitals for spin-up electrons.
orb_data_dn_spin (AOs_data | MOs_data) – Basis/orbitals for spin-down electrons.
lambda_matrix (npt.NDArray | jax.Array) – Geminal pairing matrix with shape (orb_num_up, orb_num_dn + num_electron_up - num_electron_dn).
Notes
For closed shells, orb_num_up == orb_num_dn and lambda_matrix is square.
For open shells, the extra right-hand block of num_electron_up - num_electron_dn columns encodes the unpaired spin-up orbitals.
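To make the lambda_matrix shape concrete: with orbital values \(\Phi_\uparrow\) of shape (orb_num_up, N_up) and \(\Phi_\downarrow\) of shape (orb_num_dn, N_dn), the usual AGP construction splits \(\Lambda\) column-wise into a paired block \(\Lambda_p\) (orb_num_dn columns) and an unpaired block \(\Lambda_u\) (N_up - N_dn columns), and forms the square matrix \(G = [\Phi_\uparrow^T \Lambda_p \Phi_\downarrow \mid \Phi_\uparrow^T \Lambda_u]\). The NumPy sketch below assumes this standard layout, which is consistent with the documented shape but not taken from jqmc's source, and uses random orbital values.

```python
import numpy as np


def geminal_matrix(lambda_matrix, phi_up, phi_dn):
    """Sketch of an AGP geminal matrix, shape (N_up, N_up).

    phi_up: (orb_num_up, N_up) orbital values at spin-up electrons.
    phi_dn: (orb_num_dn, N_dn) orbital values at spin-down electrons.
    lambda_matrix: (orb_num_up, orb_num_dn + N_up - N_dn).
    """
    orb_num_dn = phi_dn.shape[0]
    lam_paired = lambda_matrix[:, :orb_num_dn]    # pairs up- and down-spin orbitals
    lam_unpaired = lambda_matrix[:, orb_num_dn:]  # extra columns for unpaired up spins
    paired = phi_up.T @ lam_paired @ phi_dn       # (N_up, N_dn)
    unpaired = phi_up.T @ lam_unpaired            # (N_up, N_up - N_dn)
    return np.hstack([paired, unpaired])


rng = np.random.default_rng(1)
n_up, n_dn, orb_num = 3, 2, 6                     # open shell: one unpaired up electron
phi_up = rng.normal(size=(orb_num, n_up))
phi_dn = rng.normal(size=(orb_num, n_dn))
lam = rng.normal(size=(orb_num, orb_num + n_up - n_dn))
G = geminal_matrix(lam, phi_up, phi_dn)
print(G.shape, np.linalg.det(G))                  # (3, 3) and a scalar determinant
```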
- accumulate_position_grad(grad_geminal)
Aggregate position gradients from geminal-related structures.
- Parameters:
grad_geminal (Geminal_data)
- apply_block_update(block)
Apply a single variational-parameter block update to this Geminal object.
This method is the Geminal-specific counterpart of Wavefunction_data.apply_block_updates(). It receives a generic VariationalParameterBlock whose values have already been updated (typically by block.apply_update inside the SR/MCMC driver), and interprets that block according to the structure of the geminal (lambda) matrix.
Responsibilities of this method are:
  - Map the block name (currently "lambda_matrix") to the internal geminal parameters.
  - Handle the splitting of a rectangular lambda matrix into paired and unpaired parts when needed.
  - Enforce Geminal-specific structural constraints, especially the symmetry conditions on the paired block of the lambda matrix.
All details about how the lambda parameters are stored and constrained live here (and in the surrounding Geminal_data class), not in VariationalParameterBlock or in the optimizer. This keeps the SR/MCMC machinery and the block abstraction structure-agnostic: adding new Geminal parameters should only require updating the block construction in Wavefunction_data.get_variational_blocks and adding the corresponding handling in this method.
- Parameters:
block (VariationalParameterBlock)
- Return type:
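The splitting and symmetry constraint described for apply_block_update can be pictured with a small sketch. One simple way to enforce a symmetric paired block, assumed here purely for illustration and not confirmed as jqmc's rule, is to replace it by its symmetric part \((\Lambda_p + \Lambda_p^T)/2\), the nearest symmetric matrix in the Frobenius norm, while leaving the unpaired columns untouched. The helper split_and_symmetrize is hypothetical.

```python
import numpy as np


def split_and_symmetrize(lambda_matrix, orb_num_dn):
    """Hypothetical sketch: split lambda into paired/unpaired blocks and symmetrize the paired one.

    Assumes the paired block is square (orb_num_up == orb_num_dn), as in the closed-shell layout.
    """
    lam = np.asarray(lambda_matrix, dtype=np.float64)
    paired, unpaired = lam[:, :orb_num_dn], lam[:, orb_num_dn:]
    paired_sym = 0.5 * (paired + paired.T)  # nearest symmetric matrix (Frobenius norm)
    return np.hstack([paired_sym, unpaired])


rng = np.random.default_rng(2)
lam = rng.normal(size=(4, 5))               # 4 orbitals plus one unpaired column
lam_new = split_and_symmetrize(lam, orb_num_dn=4)
print(np.allclose(lam_new[:, :4], lam_new[:, :4].T))  # True
```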
- collect_param_grads(grad_geminal)
Collect parameter gradients into a flat dict keyed by block name.
- Parameters:
grad_geminal (Geminal_data)
- Return type:
dict[str, object]
- property compute_orb_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]
Function for computing AOs or MOs.
The API method that computes AOs or MOs, matching the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.
- Returns:
The API method to compute AOs or MOs.
- Return type:
Callable
- Raises:
NotImplementedError – If the instances of orb_data_up_spin/orb_data_dn_spin are neither AOs_data/AOs_data nor MOs_data/MOs_data.
- property compute_orb_grad_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]
Function for computing AO or MO gradients.
The API method that computes AO or MO gradients, matching the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.
- Returns:
The API method to compute AO or MO gradients.
- Return type:
Callable
- Raises:
NotImplementedError – If the instances of orb_data_up_spin/orb_data_dn_spin are neither AOs_data/AOs_data nor MOs_data/MOs_data.
- property compute_orb_laplacian_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]
Function for computing AO or MO Laplacians.
The API method that computes AO or MO Laplacians, matching the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.
- Returns:
The API method to compute AO or MO Laplacians.
- Return type:
Callable
- Raises:
NotImplementedError – If the instances of orb_data_up_spin/orb_data_dn_spin are neither AOs_data/AOs_data nor MOs_data/MOs_data.
- classmethod convert_from_MOs_to_AOs(geminal_data)
Convert MOs to AOs.
- Parameters:
geminal_data (Geminal_data)
- Return type:
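The algebra behind an MO-to-AO conversion is a change of basis. Assuming each MO is a linear combination of AOs, \(\phi^{\mathrm{MO}} = C\,\phi^{\mathrm{AO}}\), with a coefficient matrix C of shape (num_mo, num_ao), the paired block transforms as \(\Lambda_{\mathrm{AO}} = C_\uparrow^T \Lambda_{\mathrm{MO}} C_\downarrow\) (and an unpaired block as \(C_\uparrow^T \Lambda_{\mathrm{MO},u}\)), leaving the geminal matrix unchanged. The sketch checks that invariance numerically for the paired block; the variable names and coefficient layout are assumptions, not the MOs_data API.

```python
import numpy as np

rng = np.random.default_rng(3)
num_ao, num_mo, n_el = 7, 4, 3
C_up = rng.normal(size=(num_mo, num_ao))    # assumed layout: phi_MO = C @ phi_AO
C_dn = rng.normal(size=(num_mo, num_ao))
lam_mo = rng.normal(size=(num_mo, num_mo))  # paired block in the MO basis

ao_up = rng.normal(size=(num_ao, n_el))     # AO values at spin-up electrons
ao_dn = rng.normal(size=(num_ao, n_el))     # AO values at spin-down electrons

# Geminal matrix built in the MO basis.
G_mo = (C_up @ ao_up).T @ lam_mo @ (C_dn @ ao_dn)

# Same geminal expressed in the AO basis.
lam_ao = C_up.T @ lam_mo @ C_dn             # shape (num_ao, num_ao)
G_ao = ao_up.T @ lam_ao @ ao_dn

print(np.allclose(G_mo, G_ao))  # True
```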
- lambda_matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array#
Geminal pairing matrix; see class notes for expected shape.
- num_electron_dn: int = 0#
Number of spin-down electrons.
- num_electron_up: int = 0#
Number of spin-up electrons.
- orb_data_dn_spin: AOs_sphe_data | AOs_cart_data | MOs_data#
Orbital data (AOs or MOs) for spin-down electrons.
- orb_data_up_spin: AOs_sphe_data | AOs_cart_data | MOs_data#
Orbital data (AOs or MOs) for spin-up electrons.
- property orb_num_dn: int#
orb_num_dn.
The number of atomic orbitals or molecular orbitals for down electrons, depending on the instance stored in the attribute orb_data_up.
- Returns:
The number of atomic orbitals or molecular orbitals for down electrons.
- Return type:
int
- Raises:
NotImplementedError – If the instance of orb_data_dn_spin is neither AOs_data nor MOs_data.
- property orb_num_up: int#
orb_num_up.
The number of atomic orbitals or molecular orbitals for up electrons, depending on the instance stored in the attribute orb_data_up_spin.
- Returns:
The number of atomic orbitals or molecular orbitals for up electrons.
- Return type:
int
- Raises:
NotImplementedError – If the instance of orb_data_up_spin is neither AOs_data nor MOs_data.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- jqmc.determinant.compute_AS_regularization_factor(geminal_data, r_up_carts, r_dn_carts)#
Compute Attaccalite–Sorella regularization from electron coordinates.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).
- Returns:
Scalar AS regularization factor.
- Return type:
jax.Array
- jqmc.determinant.compute_AS_regularization_factor_fast_update(geminal, geminal_inv)#
Compute Attaccalite–Sorella regularization via fast update.
- Parameters:
geminal (ndarray[tuple[Any, ...], dtype[float64]]) – Geminal matrix with shape (N_up, N_up).
geminal_inv (ndarray[tuple[Any, ...], dtype[float64]]) – Inverse geminal matrix with shape (N_up, N_up).
- Returns:
Scalar AS regularization factor.
- Return type:
jax.Array
- jqmc.determinant.compute_det_geminal_all_elements(geminal_data, r_up_carts, r_dn_carts)#
Compute $\det G$ for the geminal matrix.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).
- Returns:
Scalar determinant of the geminal matrix.
- Return type:
float
- jqmc.determinant.compute_geminal_all_elements(geminal_data, r_up_carts, r_dn_carts)#
Compute geminal matrix $G$ for all electron pairs.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).
- Returns:
Geminal matrix with shape (N_up, N_up) combining paired and unpaired blocks.
- Return type:
jax.Array
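As a rough illustration of the paired/unpaired block structure, the NumPy sketch below assembles a geminal matrix from precomputed orbital values. It assumes a layout that is not stated in this section: lambda_matrix has shape (orb_num_up, orb_num_dn + N_unpaired), its first orb_num_dn columns form the paired block, the remaining columns form the unpaired block, and N_unpaired = N_up - N_dn. The helper build_geminal is hypothetical and is not the jqmc implementation.

```python
import numpy as np


def build_geminal(phi_up, phi_dn, lambda_matrix):
    """Hypothetical AGP-style geminal sketch (assumed layout, not jqmc code).

    phi_up: (orb_num_up, N_up) orbital values at the spin-up electrons.
    phi_dn: (orb_num_dn, N_dn) orbital values at the spin-down electrons.
    lambda_matrix: (orb_num_up, orb_num_dn + N_unpaired).
    """
    orb_num_dn = phi_dn.shape[0]
    lambda_paired = lambda_matrix[:, :orb_num_dn]
    lambda_unpaired = lambda_matrix[:, orb_num_dn:]
    # Paired block: G_ij = sum_kl phi_k(r_up_i) lambda_kl phi_l(r_dn_j), shape (N_up, N_dn).
    paired = phi_up.T @ lambda_paired @ phi_dn
    # Unpaired block: one column per unpaired orbital, shape (N_up, N_unpaired).
    unpaired = phi_up.T @ lambda_unpaired
    return np.hstack([paired, unpaired])


rng = np.random.default_rng(0)
orb_num, n_up, n_dn = 6, 3, 2
n_unpaired = n_up - n_dn
phi_up = rng.normal(size=(orb_num, n_up))
phi_dn = rng.normal(size=(orb_num, n_dn))
lambda_matrix = rng.normal(size=(orb_num, orb_num + n_unpaired))

G = build_geminal(phi_up, phi_dn, lambda_matrix)
print(G.shape)  # (3, 3): square, so det G is defined
print(np.linalg.det(G))
```

In this picture, one row of G (N_dn + N_unpaired entries) is what a single spin-up electron contributes, and one column of the paired block is what a single spin-down electron contributes.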
- jqmc.determinant.compute_geminal_dn_one_column_elements(geminal_data, r_up_carts, r_dn_cart)#
Single column of the geminal matrix for one spin-down electron.
- Parameters:
geminal_data – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_cart (Array) – Cartesian coordinate for one spin-down electron with shape (3,) or (1, 3).
- Returns:
Column vector for the paired block with shape (N_up,).
- Return type:
jax.Array
- jqmc.determinant.compute_geminal_up_one_row_elements(geminal_data, r_up_cart, r_dn_carts)#
Single row of the geminal matrix for one spin-up electron.
- Parameters:
geminal_data – Geminal parameters and orbital references.
r_up_cart (Array) – Cartesian coordinate for one spin-up electron with shape (3,) or (1, 3).
r_dn_carts (Array) – Cartesian coordinates for all spin-down electrons with shape (N_dn, 3).
- Returns:
Row vector with shape (N_dn + N_unpaired,).
- Return type:
jax.Array
- jqmc.determinant.compute_grads_and_laplacian_ln_Det(geminal_data, r_up_carts, r_dn_carts)#
Gradients and Laplacians of $\ln \det G$ for each electron.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).
- Returns:
Gradients for spin-up electrons with shape (N_up, 3).
Gradients for spin-down electrons with shape (N_dn, 3).
Laplacians for spin-up electrons with shape (N_up,).
Laplacians for spin-down electrons with shape (N_dn,).
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
- jqmc.determinant.compute_grads_and_laplacian_ln_Det_fast(geminal_data, r_up_carts, r_dn_carts, geminal_inverse)#
Gradients and Laplacians of ln det G using a precomputed inverse.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).
r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).
geminal_inverse (Array) – Precomputed inverse of the geminal matrix G.
- Returns:
Gradients (up/down) and Laplacians (up/down) of ln det G per electron.
- Return type:
tuple[Array, Array, Array, Array]
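The identities behind a precomputed-inverse evaluation can be checked numerically. When only row $i$ of $G$ depends on electron $i$ (the spin-up case; the column case for spin-down electrons is analogous), $\nabla_i \ln\det G = \sum_j (G^{-1})_{ji}\,\nabla_i G_{ij}$ and $\nabla_i^2 \ln\det G = \sum_j (G^{-1})_{ji}\,\nabla_i^2 G_{ij} - |\nabla_i \ln\det G|^2$. The NumPy sketch below verifies both against central finite differences for a toy matrix $G_{ij} = \exp(-a_j |\mathbf r_i - \mathbf c_j|^2)$; the exponents, centers and positions are invented for illustration and do not correspond to any jqmc basis.

```python
import numpy as np

# Toy setup (invented): one Gaussian orbital per column, one electron per row.
alphas = np.array([0.5, 0.7, 0.9, 0.6])
centers = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]])
rng = np.random.default_rng(1)
r = centers + 0.3 * rng.normal(size=centers.shape)


def toy_matrix(r):
    # G_ij = exp(-a_j |r_i - c_j|^2): row i depends only on electron i.
    d2 = ((r[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
    return np.exp(-alphas[None, :] * d2)


def toy_derivatives(r):
    # Analytic gradient (n, n, 3) and Laplacian (n, n) of each entry with respect to r_i.
    diff = r[:, None, :] - centers[None, :, :]
    d2 = (diff ** 2).sum(-1)
    g = np.exp(-alphas[None, :] * d2)
    grad = -2.0 * alphas[None, :, None] * diff * g[..., None]
    lap = (4.0 * alphas[None, :] ** 2 * d2 - 6.0 * alphas[None, :]) * g
    return grad, lap


def grads_and_laplacian_ln_det(G_inv, dG, d2G):
    # grad_i ln det G = sum_j G_inv[j, i] * dG[i, j, :]
    grads = np.einsum("ji,ijk->ik", G_inv, dG)
    # lap_i ln det G = sum_j G_inv[j, i] * d2G[i, j] - |grad_i ln det G|^2
    laps = np.einsum("ji,ij->i", G_inv, d2G) - (grads ** 2).sum(-1)
    return grads, laps


G = toy_matrix(r)
dG, d2G = toy_derivatives(r)
grads, laps = grads_and_laplacian_ln_det(np.linalg.inv(G), dG, d2G)


def ln_abs_det(r):
    return np.linalg.slogdet(toy_matrix(r))[1]


# Central finite differences with respect to electron 0.
h = 1e-4
f0 = ln_abs_det(r)
fd_grad = np.zeros(3)
fd_lap = 0.0
for k in range(3):
    step = np.zeros_like(r)
    step[0, k] = h
    fp, fm = ln_abs_det(r + step), ln_abs_det(r - step)
    fd_grad[k] = (fp - fm) / (2 * h)
    fd_lap += (fp - 2 * f0 + fm) / h**2

print(grads[0], fd_grad)
print(laps[0], fd_lap)
```

Once $G^{-1}$ is available, every per-electron gradient and Laplacian costs only a dot product with one row of entry derivatives, which is presumably why a precomputed inverse is accepted as an argument.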
- jqmc.determinant.compute_ratio_determinant_part(geminal_data, A_old_inv, old_r_up_carts, old_r_dn_carts, new_r_up_carts_arr, new_r_dn_carts_arr)#
Determinant ratio $\det G(\mathbf r')/\det G(\mathbf r)$ for batched moves.
- Parameters:
geminal_data (Geminal_data) – Geminal parameters and orbital references.
A_old_inv (Array) – Inverse geminal matrix for the reference configuration with shape (N_up, N_up).
old_r_up_carts (Array) – Original spin-up electron coordinates with shape (N_up, 3).
old_r_dn_carts (Array) – Original spin-down electron coordinates with shape (N_dn, 3).
new_r_up_carts_arr (Array) – Proposed spin-up coordinates per grid with shape (N_grid, N_up, 3).
new_r_dn_carts_arr (Array) – Proposed spin-down coordinates per grid with shape (N_grid, N_dn, 3).
- Returns:
Determinant ratios per grid with shape (N_grid,).
- Return type:
jax.Array
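When a proposed move displaces a single spin-up electron $i$, only row $i$ of $G$ changes, and the matrix determinant lemma gives the ratio without computing a new determinant: $\det G'/\det G = \sum_j G'_{ij}(G^{-1})_{ji}$. The NumPy sketch below checks this identity on a random matrix for several candidate rows at once. It illustrates why the old inverse is an input; it is not the jqmc implementation, and how that routine treats grids that move more than one electron is not covered here.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 5
G = rng.normal(size=(n, n)) + n * np.eye(n)  # well-conditioned reference matrix
G_inv = np.linalg.inv(G)

# Each proposed configuration replaces row i with a new row.
i = 2
n_grid = 4
new_rows = rng.normal(size=(n_grid, n))

# Fast ratios: det G' / det G = new_row @ G_inv[:, i], one per grid point.
fast_ratios = new_rows @ G_inv[:, i]

# Reference: explicit determinants.
slow_ratios = []
for row in new_rows:
    G_new = G.copy()
    G_new[i] = row
    slow_ratios.append(np.linalg.det(G_new) / np.linalg.det(G))
slow_ratios = np.array(slow_ratios)

print(np.allclose(fast_ratios, slow_ratios))  # True
```

The fast path costs O(N) per grid point instead of the O(N^3) of a fresh determinant.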
jqmc.diff_mask module#
Utilities for masking PyTree leaves by derivative type.
This module centralizes the logic for selectively stopping gradients on parameter- or coordinate-like leaves so we can reuse the same dataclasses for all derivative modes (params, positions, or none) without proliferating specialized subclasses.
- class jqmc.diff_mask.DiffMask(params=True, coords=True)#
Bases:
object
Simple mask controlling which leaf types remain differentiable.
- Parameters:
params (bool)
coords (bool)
- coords: bool = True#
- params: bool = True#
- jqmc.diff_mask.apply_diff_mask(obj, mask)#
Return a copy of obj with gradients stopped according to mask.
The function recurses through dataclass fields (including Flax struct dataclasses), honoring optional diff_tag metadata on fields when present. Otherwise, a small set of field-name heuristics is used to decide whether a leaf should be treated as a parameter (“param”) or coordinate (“coord”). Lists/tuples are traversed elementwise; everything else is returned unchanged.
- Parameters:
obj (Any)
mask (DiffMask)
- Return type:
Any
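The traversal described for apply_diff_mask can be sketched in plain Python, as an assumption-laden illustration rather than jqmc's code. Here stop is a caller-supplied function standing in for jax.lax.stop_gradient (so the example runs without JAX), PARAM_NAMES and COORD_NAMES are hypothetical placeholders for the unlisted field-name heuristics, DiffMaskSketch mirrors the documented params/coords flags, and only standard dataclasses (not Flax struct dataclasses) are handled.

```python
import dataclasses
from typing import Any, Callable, Optional

# Hypothetical name heuristics; the actual lists used by jqmc are not documented here.
PARAM_NAMES = {"lambda_matrix", "j_matrix", "jastrow_1b_param", "jastrow_2b_param"}
COORD_NAMES = {"positions"}


@dataclasses.dataclass(frozen=True)
class DiffMaskSketch:
    params: bool = True  # True: parameter leaves stay differentiable
    coords: bool = True  # True: coordinate leaves stay differentiable


def classify(f: dataclasses.Field) -> Optional[str]:
    # Explicit diff_tag metadata wins; otherwise fall back on the field name.
    tag = f.metadata.get("diff_tag")
    if tag is not None:
        return tag
    if f.name in PARAM_NAMES:
        return "param"
    if f.name in COORD_NAMES:
        return "coord"
    return None


def apply_mask_sketch(obj: Any, mask: DiffMaskSketch, stop: Callable[[Any], Any]) -> Any:
    """Return a copy of obj with `stop` applied to masked-out leaves."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        updates = {}
        for f in dataclasses.fields(obj):
            value = getattr(obj, f.name)
            kind = classify(f)
            if (kind == "param" and not mask.params) or (kind == "coord" and not mask.coords):
                updates[f.name] = stop(value)
            else:
                updates[f.name] = apply_mask_sketch(value, mask, stop)
        return dataclasses.replace(obj, **updates)
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_mask_sketch(x, mask, stop) for x in obj)
    return obj  # anything else is returned unchanged


# Usage with toy dataclasses and a marker function in place of jax.lax.stop_gradient.
@dataclasses.dataclass(frozen=True)
class ToyStructure:
    positions: tuple


@dataclasses.dataclass(frozen=True)
class ToyWavefunction:
    structure: ToyStructure
    j_matrix: float
    scale: float = dataclasses.field(default=1.0, metadata={"diff_tag": "param"})


def mark_stopped(x):
    return ("stopped", x)


wf = ToyWavefunction(structure=ToyStructure(positions=(0.0, 1.0)), j_matrix=0.5)
out = apply_mask_sketch(wf, DiffMaskSketch(params=True, coords=False), mark_stopped)
print(out.structure.positions)  # ('stopped', (0.0, 1.0))
print(out.j_matrix)  # 0.5
```

Keeping the tagging in field metadata (with name heuristics only as a fallback) is what lets one set of dataclasses serve every derivative mode, as the module description states.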
jqmc.function_collections module#
Collections of useful functions.
jqmc.hamiltonians module#
Hamiltonian module.
- class jqmc.hamiltonians.Hamiltonian_data(structure_data=<factory>, coulomb_potential_data=<factory>, wavefunction_data=<factory>)#
Bases:
object
Hamiltonian dataclass.
The class contains data for computing kinetic and potential energy terms.
- Parameters:
structure_data (Structure_data) – an instance of Structure_data
coulomb_potential_data (Coulomb_potential_data) – an instance of Coulomb_potential_data
wavefunction_data (Wavefunction_data) – an instance of Wavefunction_data
Notes
Here are the differentiable arguments, i.e., those with pytree_node = True. This design is somewhat at odds with the object-oriented principle ‘Tell, don’t ask’ (Hamiltonian_data knows too much about the details of the other classes), but there is no other way to switch pytree_nodes on and off dynamically depending on the variational parameters a user chooses to optimize, because @dataclass is statically generated.
- WF parameters related:
lambda in wavefunction_data.geminal_data (determinant.py)
jastrow_2b_param in wavefunction_data.jastrow_data.jastrow_two_body_data (jastrow_factor.py)
j_matrix in wavefunction_data.jastrow_data.jastrow_three_body_data (jastrow_factor.py)
- Atomic positions related:
positions in hamiltonian_data.structure_data (this file)
positions in wavefunction_data.geminal_data.mos_data/aos_data.structure_data (molecular_orbital.py/atomic_orbital.py)
positions in wavefunction_data.jastrow_data.jastrow_three_body_data.mos_data/aos_data.structure_data (jastrow_factor.py)
positions in Coulomb_potential_data.structure_data (coulomb_potential.py)
- accumulate_position_grad(grad_hamiltonian)#
Aggregate position gradients from Hamiltonian components (structure + wavefunction).
- Parameters:
grad_hamiltonian (Hamiltonian_data)
- coulomb_potential_data: Coulomb_potential_data#
- static load_from_hdf5(filepath='jqmc.h5')#
Load Hamiltonian data from an HDF5 file.
- Parameters:
filepath (str, optional) – file path
- Returns:
An instance of Hamiltonian_data.
- Return type:
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- save_to_hdf5(filepath='jqmc.h5')#
Save Hamiltonian data to an HDF5 file.
- Parameters:
filepath (str, optional) – file path
- Return type:
None
- structure_data: Structure_data#
- wavefunction_data: Wavefunction_data#
- jqmc.hamiltonians.compute_local_energy(hamiltonian_data, r_up_carts, r_dn_carts, RT)#
Compute Local Energy.
The method is for computing the local energy at (r_up_carts, r_dn_carts).
- Parameters:
hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data
r_up_carts (jnpt.ArrayLike) – Cartesian coordinates of up-spin electrons (dim: N_e^{up}, 3)
r_dn_carts (jnpt.ArrayLike) – Cartesian coordinates of dn-spin electrons (dim: N_e^{dn}, 3)
RT (jnpt.ArrayLike) – Rotation matrix, equivalent to R.T, used for the non-local part. It does not affect all-electron calculations.
- Returns:
The value of local energy (e_L) with the given wavefunction (float)
- Return type:
float
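The local energy is $e_L = (\hat H\Psi)/\Psi$. Writing the kinetic part in terms of $\ln\Psi$ gives $e_L = -\tfrac12\sum_i\left[\nabla_i^2\ln\Psi + |\nabla_i\ln\Psi|^2\right] + V$, which is consistent with the gradient and Laplacian routines in this package returning per-electron quantities of logarithms. The pure-Python sketch below applies this expression to a hydrogen atom with the exact ground state $\Psi = e^{-r}$, for which $e_L$ must equal $-0.5$ Hartree at every point. It is a textbook check, not a call into jqmc.

```python
import math


def local_energy_hydrogen(x: float, y: float, z: float) -> float:
    """e_L for Psi = exp(-r) around a proton at the origin (atomic units)."""
    r = math.sqrt(x * x + y * y + z * z)
    # ln Psi = -r, so grad ln Psi = -r_hat (|grad|^2 = 1) and lap ln Psi = -2/r.
    grad_sq = 1.0
    lap_ln_psi = -2.0 / r
    kinetic = -0.5 * (lap_ln_psi + grad_sq)
    potential = -1.0 / r  # electron-nucleus Coulomb attraction
    return kinetic + potential


for point in [(0.3, 0.0, 0.0), (1.0, 2.0, -0.5), (-4.0, 0.1, 3.0)]:
    print(round(local_energy_hydrogen(*point), 12))  # -0.5 at every point
```

For an exact eigenstate the local energy has zero variance; for an approximate trial wavefunction it fluctuates, and those fluctuations are what the QMC estimators average over.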
jqmc.jastrow_factor module#
Jastrow module.
- class jqmc.jastrow_factor.Jastrow_NN_data(nn_def=None, params=None, flat_shape=(), num_params=0, treedef=None, shapes=<factory>, hidden_dim=64, num_layers=3, num_rbf=16, cutoff=5.0, num_species=0, species_lookup=(0, ), species_values=<factory>, structure_data=None)#
Bases:
object
Container for NN-based Jastrow factor.
This dataclass stores both the neural network definition and its parameters, together with helper functions that integrate the NN Jastrow term into the variational-parameter block machinery.
The intended usage is:
- nn_def holds a Flax/SchNet-like module (e.g. NNJastrow).
- params holds the corresponding PyTree of parameters.
- flatten_fn/unflatten_fn convert between the PyTree and a 1D parameter vector for SR/MCMC.
- If this dataclass is set to None inside Jastrow_data, the NN contribution is simply turned off. If it is not None, its contribution is evaluated and added on top of the analytic three-body Jastrow (if present).
- Parameters:
nn_def (Any)
params (Any)
flat_shape (tuple[int, ...])
num_params (int)
treedef (Any)
shapes (list[tuple[int, ...]])
hidden_dim (int)
num_layers (int)
num_rbf (int)
cutoff (float)
num_species (int)
species_lookup (tuple[int, ...])
species_values (tuple[int, ...])
structure_data (Structure_data | None)
- cutoff: float = 5.0#
Radial cutoff for features.
- flat_shape: tuple[int, ...] = ()#
Shape of flattened params.
- property flatten_fn: Callable[[Any], Array]#
Return a flatten function built from treedef.
This is constructed on each access and is not part of the serialized state (so it will not cause pickle errors).
- hidden_dim: int = 64#
Hidden width used in NNJastrow.
- classmethod init_from_structure(structure_data, hidden_dim=64, num_layers=3, num_rbf=16, cutoff=5.0, key=None)#
Initialize NN Jastrow from structure information.
This creates a PauliNet-style NNJastrow module, initializes its parameters with a dummy electron configuration, and prepares flatten/unflatten utilities for SR/MCMC.
- Parameters:
structure_data (Structure_data)
hidden_dim (int)
num_layers (int)
num_rbf (int)
cutoff (float)
- Return type:
- nn_def: Any = None#
Flax module definition (e.g., NNJastrow).
- num_layers: int = 3#
Number of PauliNet blocks.
- num_params: int = 0#
Total number of parameters.
- num_rbf: int = 16#
PhysNet radial basis size.
- num_species: int = 0#
Count of unique nuclear species.
- params: Any = None#
Parameter PyTree for nn_def.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- shapes: list[tuple[int, ...]]#
Per-leaf shapes.
- species_lookup: tuple[int, ...] = (0,)#
Lookup table mapping Z to species ids.
- species_values: tuple[int, ...]#
Sorted unique atomic numbers used.
- structure_data: Structure_data | None = None#
Structure info required for NN evaluation.
- treedef: Any = None#
PyTree treedef for params.
- property unflatten_fn: Callable[[Array], Any]#
Return an unflatten function built from treedef and shapes.
As with flatten_fn(), this is constructed on each access and not stored inside the pickled state.
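A simplified picture of the flatten_fn/unflatten_fn pair: the NumPy sketch below treats the parameter PyTree as a plain list of arrays (jqmc builds its functions from a JAX treedef, which this sketch does not use) and records per-leaf shapes, loosely mirroring the shapes and num_params fields. make_flatten_fns is a hypothetical helper, not part of jqmc.

```python
import numpy as np


def make_flatten_fns(leaves):
    """Hypothetical helper: build flatten/unflatten closures for a list of arrays."""
    shapes = [np.shape(leaf) for leaf in leaves]
    sizes = [int(np.prod(s)) for s in shapes]
    num_params = sum(sizes)

    def flatten_fn(tree):
        # Concatenate all leaves into a single 1D parameter vector.
        return np.concatenate([np.ravel(leaf) for leaf in tree])

    def unflatten_fn(vec):
        # Slice the vector back into leaves with the recorded shapes.
        out, offset = [], 0
        for shape, size in zip(shapes, sizes):
            out.append(np.reshape(vec[offset:offset + size], shape))
            offset += size
        return out

    return flatten_fn, unflatten_fn, shapes, num_params


params = [np.ones((2, 3)), np.arange(4.0), np.array(0.5)]  # toy "PyTree" leaves
flatten_fn, unflatten_fn, shapes, num_params = make_flatten_fns(params)

vec = flatten_fn(params)
print(vec.shape, num_params)  # (11,) 11
restored = unflatten_fn(vec)
print([leaf.shape for leaf in restored])  # [(2, 3), (4,), ()]
```

Because the closures depend only on the recorded shapes, they are cheap to rebuild, which fits the note that these functions are constructed on each access rather than pickled.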
- class jqmc.jastrow_factor.Jastrow_data(jastrow_one_body_data=None, jastrow_two_body_data=None, jastrow_three_body_data=None, jastrow_nn_data=None)#
Bases:
object
Jastrow dataclass.
The class contains data for evaluating a Jastrow function.
- Parameters:
jastrow_one_body_data (Jastrow_one_body_data) – An instance of Jastrow_one_body_data. If None, the one-body Jastrow is turned off.
jastrow_two_body_data (Jastrow_two_body_data) – An instance of Jastrow_two_body_data. If None, the two-body Jastrow is turned off.
jastrow_three_body_data (Jastrow_three_body_data) – An instance of Jastrow_three_body_data. If None, the three-body Jastrow is turned off.
jastrow_nn_data (Jastrow_NN_data | None) – Optional container for an NN-based three-body Jastrow term. If None, the NN Jastrow contribution is turned off.
- accumulate_position_grad(grad_jastrow)#
Aggregate position gradients from all active Jastrow components.
- Parameters:
grad_jastrow (Jastrow_data)
- apply_block_update(block)#
Apply a single variational-parameter block update to this Jastrow object.
This method is the Jastrow-specific counterpart of Wavefunction_data.apply_block_updates(). It receives a generic VariationalParameterBlock whose values have already been updated (typically by block.apply_update inside the SR/MCMC driver), and interprets that block according to Jastrow semantics.
Responsibilities of this method are:
- Map the block name (e.g. "j1_param", "j2_param", "j3_matrix") to the corresponding internal Jastrow field(s).
- Enforce Jastrow-specific structural constraints when copying the block values into the internal arrays. In particular, for the three-body Jastrow term (J3) this includes:
  - Handling the case where only the last column is variational and the rest of the matrix is constrained.
  - Handling the fully square J3 matrix case.
  - Enforcing the required symmetry of the square J3 block.
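One simple way to picture these J3 constraints is sketched below with NumPy. It assumes the (orb_num, orb_num + 1) j_matrix layout documented for Jastrow_three_body_data, covers only that layout (not the fully square case), and enforces symmetry of the square block by averaging it with its transpose. Whether jqmc symmetrizes this way is an assumption; the function name apply_j3_update and the last_column_only flag are hypothetical.

```python
import numpy as np


def apply_j3_update(j_matrix, new_values, last_column_only=False):
    """Hypothetical J3 block update (not the jqmc implementation)."""
    j = np.array(j_matrix, dtype=float, copy=True)
    new_values = np.asarray(new_values, dtype=float)
    if last_column_only:
        # Only the J1-like last column is variational; the square block stays fixed.
        j[:, -1] = new_values[:, -1]
        return j
    square = new_values[:, :-1]
    # Enforce the symmetry of the square J3 block by averaging with its transpose.
    j[:, :-1] = 0.5 * (square + square.T)
    j[:, -1] = new_values[:, -1]
    return j


rng = np.random.default_rng(3)
orb_num = 4
old = np.zeros((orb_num, orb_num + 1))
proposal = rng.normal(size=(orb_num, orb_num + 1))

updated = apply_j3_update(old, proposal)
print(np.allclose(updated[:, :-1], updated[:, :-1].T))  # True
```

Averaging is the orthogonal projection onto symmetric matrices, so it changes a nearly symmetric proposal as little as possible.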
By keeping all J1/J2/J3 interpretation and constraints in this method (and in the surrounding Jastrow_data class), the optimizer and VariationalParameterBlock remain completely structure-agnostic. To introduce a new Jastrow parameter, extend the block construction in Wavefunction_data.get_variational_blocks and add the corresponding handling here, without touching the SR/MCMC driver.
- Parameters:
block (VariationalParameterBlock)
- Return type:
- collect_param_grads(grad_jastrow)#
Collect parameter gradients into a flat dict keyed by block name.
- Parameters:
grad_jastrow (Jastrow_data)
- Return type:
dict[str, object]
- jastrow_nn_data: Jastrow_NN_data | None = None#
Optional NN-based three-body Jastrow component.
- jastrow_one_body_data: Jastrow_one_body_data | None = None#
Optional one-body Jastrow component.
- jastrow_three_body_data: Jastrow_three_body_data | None = None#
Optional analytic three-body Jastrow component.
- jastrow_two_body_data: Jastrow_two_body_data | None = None#
Optional two-body Jastrow component.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- class jqmc.jastrow_factor.Jastrow_one_body_data(jastrow_1b_param=1.0, structure_data=<factory>, core_electrons=<factory>)#
Bases:
object
One-body Jastrow parameters and structure metadata.
The one-body term models electron–nucleus correlations with an exponential functional form. The numerical value is returned without the exp wrapper; callers attach exp(J) to the wavefunction.
- Parameters:
jastrow_1b_param (float) – Parameter controlling the one-body decay.
structure_data (Structure_data) – Nuclear positions and charges.
core_electrons (tuple[float]) – Removed core electrons per nucleus (for ECPs).
- core_electrons: list[float] | tuple[float]#
Effective core-electron counts aligned with structure_data.
- classmethod init_jastrow_one_body_data(jastrow_1b_param, structure_data, core_electrons)#
Initialization.
- jastrow_1b_param: float = 1.0#
One-body Jastrow exponent parameter.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- structure_data: Structure_data#
Nuclear structure data providing positions and atomic numbers.
- class jqmc.jastrow_factor.Jastrow_three_body_data(orb_data=<factory>, j_matrix=<factory>)#
Bases:
object
Three-body Jastrow parameters and orbital references.
The three-body term uses a matrix layout made of a square J3 block plus a last-column J1-like vector. Values are returned without exponentiation; callers attach exp(J) to the wavefunction.
- Parameters:
orb_data (AOs_sphe_data | AOs_cart_data | MOs_data) – Basis/orbital data used for both spins.
j_matrix (npt.NDArray | jax.Array) – J matrix with shape (orb_num, orb_num + 1).
- property compute_orb_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]#
Function for computing AOs or MOs.
The api method to compute AOs or MOs corresponding to the instance stored in self.orb_data.
- Returns:
The api method to compute AOs or MOs.
- Return type:
Callable
- Raises:
NotImplementedError – If the instance stored in orb_data is neither AOs_data nor MOs_data.
- classmethod init_jastrow_three_body_data(orb_data, random_init=False, random_scale=0.01, seed=None)#
Initialization.
- Parameters:
orb_data (AOs_sphe_data | AOs_cart_data | MOs_data) – Orbital container (AOs or MOs) used to size the J-matrix.
random_init (bool) – If True, initialize with small random values instead of zeros (for tests).
random_scale (float) – Upper bound of uniform sampler when random_init is True (default 0.01).
seed (int | None) – Optional seed for deterministic initialization when random_init is True.
- j_matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array#
J3/J1 matrix; square block plus final column.
- orb_data: AOs_sphe_data | AOs_cart_data | MOs_data#
Orbital basis (AOs or MOs) shared across spins.
- property orb_num: int#
Get the number of atomic orbitals.
- Returns:
The number of atomic orbitals.
- Return type:
int
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- class jqmc.jastrow_factor.Jastrow_two_body_data(jastrow_2b_param=1.0)#
Bases:
object
Two-body Jastrow parameter container.
The two-body term uses the Pade functional form J2(r) = r / (2 * (1 + a r)) with a = jastrow_2b_param. Values are returned without exponentiation; callers use exp(J) when constructing the wavefunction.
- Parameters:
jastrow_2b_param (float) – Parameter for the two-body Jastrow part.
- classmethod init_jastrow_two_body_data(jastrow_2b_param=1.0)#
Initialization.
- jastrow_2b_param: float = 1.0#
Pade a parameter for J2.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks that the arguments are mutually consistent.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- class jqmc.jastrow_factor.NNJastrow(hidden_dim=64, num_layers=3, num_rbf=32, cutoff=5.0, species_lookup=None, num_species=None, parent=<flax.linen.module._Sentinel object>, name=None)#
Bases:
Module
PauliNet-inspired NN that outputs a three-body Jastrow correction.
The network implements the iteration rules described in the PauliNet manuscript (Eq. 1–2). Electron embeddings \(\mathbf{x}_i^{(n)}\) are iteratively refined by three message channels:
- (+): same-spin electrons, enforcing antisymmetry indirectly by keeping the messages exchange-equivariant.
- (-): opposite-spin electrons, capturing pairing terms.
- (n): nuclei, represented by fixed species embeddings.
After num_layers iterations the final electron embeddings are summed and fed through \(\eta_\theta\) to produce a symmetric correction that is added on top of the analytic three-body Jastrow.
- Parameters:
hidden_dim (int)
num_layers (int)
num_rbf (int)
cutoff (float)
species_lookup (ndarray[tuple[Any, ...], dtype[int32]] | Array | tuple[int, ...] | None)
num_species (int | None)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- class PauliNetBlock(hidden_dim, parent=<flax.linen.module._Sentinel object>, name=None)#
Bases:
Module
Single PauliNet message-passing iteration following Eq. (1).
Each block mixes three message channels per electron: same-spin (+), opposite-spin (-), and nucleus-electron (n). The sender network is shared across channels to match the PauliNet weight-tying scheme, while separate weighting/receiver networks parameterize the contribution of every channel.
- Parameters:
hidden_dim (int)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- name: str | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()#
Instantiate the shared sender/receiver networks for this block.
- class PhysNetRadialLayer(num_rbf, cutoff, parent=<flax.linen.module._Sentinel object>, name=None)#
Bases:
Module
Cuspless PhysNet-inspired radial features \(e_k(r)\).
The basis follows Eq. (3) in the PauliNet supplement with a PhysNet-style envelope that forces both the value and the derivative of each Gaussian to vanish at the cutoff and the origin. These features are reused across all message channels, ensuring consistent geometric encoding.
- Parameters:
num_rbf (int)
cutoff (float)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- cutoff: float#
- name: str | None = None#
- num_rbf: int#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
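The envelope properties claimed for PhysNetRadialLayer (value and derivative vanishing at the origin and at the cutoff) can be illustrated with one possible construction: a polynomial cutoff of the kind used in PhysNet, $f(x) = 1 - 6x^5 + 15x^4 - 10x^3$ with $x = r/r_c$, multiplied by $(r/r_c)^2$ and a Gaussian centered at $\mu_k$. This exact expression, the evenly spaced centers and the width are assumptions for illustration; jqmc's layer may use a different formula.

```python
import math


def physnet_cutoff(x: float) -> float:
    # Smooth cutoff: f(0) = 1, f(1) = 0, and f'(0) = f'(1) = 0.
    if x >= 1.0:
        return 0.0
    return 1.0 - 6.0 * x**5 + 15.0 * x**4 - 10.0 * x**3


def radial_features(r: float, num_rbf: int = 8, cutoff: float = 5.0) -> list[float]:
    """Hypothetical e_k(r); the r^2 prefactor removes value and slope at r = 0."""
    x = r / cutoff
    width = cutoff / num_rbf
    centers = [cutoff * (k + 0.5) / num_rbf for k in range(num_rbf)]
    envelope = x**2 * physnet_cutoff(x)
    return [envelope * math.exp(-(((r - mu) / width) ** 2)) for mu in centers]


# Finite-difference slopes at the origin and at the cutoff should both be ~0.
h = 1e-6
for r0 in (0.0, 5.0):
    left = radial_features(max(r0 - h, 0.0))
    right = radial_features(r0 + h)
    slope = max(abs(a - b) / (2 * h) for a, b in zip(right, left))
    print(r0, max(radial_features(r0)), slope)
```

A vanishing slope at the origin is what makes such features cuspless, so electron–electron cusp conditions remain controlled by the analytic Jastrow terms.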
- class TwoLayerMLP(width, out_dim, parent=<flax.linen.module._Sentinel object>, name=None)#
Bases:
Module
Utility MLP used for \(w_\theta\), \(h_\theta\), \(g_\theta\), and \(\eta_\theta\).
- Parameters:
width (int)
out_dim (int)
parent (Module | Scope | _Sentinel | None)
name (str | None)
- name: str | None = None#
- out_dim: int#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- width: int#
- cutoff: float = 5.0#
- name: str | None = None#
- num_layers: int = 3#
- num_rbf: int = 32#
- num_species: int | None = None#
- parent: Module | Scope | _Sentinel | None = None#
- scope: Scope | None = None#
- setup()#
Instantiate PauliNet components and validate required metadata.
- Raises:
ValueError – If species_lookup or num_species were not provided via the host dataclass before module initialization.
- species_lookup: ndarray[tuple[Any, ...], dtype[int32]] | Array | tuple[int, ...] | None = None#
- jqmc.jastrow_factor.compute_Jastrow_one_body(jastrow_one_body_data, r_up_carts, r_dn_carts)#
Evaluate the one-body Jastrow $J_1$ (without exp) for given coordinates.
This routine returns the scalar J value; callers attach exp(J) to the wavefunction.
- Parameters:
jastrow_one_body_data (Jastrow_one_body_data) – One-body Jastrow parameters and structure data.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
One-body Jastrow value (before exponentiation).
- Return type:
float
- jqmc.jastrow_factor.compute_Jastrow_part(jastrow_data, r_up_carts, r_dn_carts)#
Evaluate the total Jastrow J = J1 + J2 + J3 (without exponentiation).
The returned scalar J excludes the exp factor; callers apply exp(J) to the wavefunction. Both the analytic three-body and optional NN three-body contributions are included.
- Parameters:
jastrow_data (Jastrow_data) – Collection of active Jastrow components (J1/J2/J3/NN).
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Total Jastrow value before exponentiation.
- Return type:
float
- jqmc.jastrow_factor.compute_Jastrow_three_body(jastrow_three_body_data, r_up_carts, r_dn_carts)#
Evaluate the three-body Jastrow $J_3$ (analytic) without exponentiation.
The square J3 block couples electron pairs and the last column acts as a J1-like vector. The returned value is J; attach exp(J) externally.
- Parameters:
jastrow_three_body_data (Jastrow_three_body_data) – Three-body Jastrow parameters and orbitals.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Three-body Jastrow value (before exponentiation).
- Return type:
float
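For intuition, the NumPy sketch below evaluates an analytic three-body term under explicit assumptions that are not spelled out in this section: orbital values are stacked as phi[a, i] over all electrons of both spins, the square block couples every unordered pair i < j, and the last column contributes $\sum_i\sum_a J_{a,\mathrm{last}}\,\phi_a(\mathbf r_i)$. It compares a vectorized form with a direct double loop; both function names are hypothetical.

```python
import numpy as np


def j3_vectorized(j_matrix, phi):
    """phi: (orb_num, N_e) orbital values at all electrons (assumed layout)."""
    j_square, j_last = j_matrix[:, :-1], j_matrix[:, -1]
    pair = phi.T @ j_square @ phi  # (N_e, N_e): phi(r_i)^T J phi(r_j)
    upper = np.triu(pair, k=1).sum()  # unordered pairs i < j
    one_body = (j_last @ phi).sum()  # J1-like last-column term
    return upper + one_body


def j3_loops(j_matrix, phi):
    n_orb, n_e = phi.shape
    total = 0.0
    for i in range(n_e):
        total += sum(j_matrix[a, -1] * phi[a, i] for a in range(n_orb))
        for j in range(i + 1, n_e):
            total += phi[:, i] @ j_matrix[:, :-1] @ phi[:, j]
    return total


rng = np.random.default_rng(4)
orb_num, n_e = 5, 4
square = rng.normal(size=(orb_num, orb_num))
j_matrix = np.hstack([0.5 * (square + square.T), rng.normal(size=(orb_num, 1))])
phi = rng.normal(size=(orb_num, n_e))
print(np.isclose(j3_vectorized(j_matrix, phi), j3_loops(j_matrix, phi)))  # True
```

With a symmetric square block, the pair sum does not depend on how electrons are labeled, which is the property a Jastrow factor needs.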
- jqmc.jastrow_factor.compute_Jastrow_two_body(jastrow_two_body_data, r_up_carts, r_dn_carts)#
Evaluate the two-body Jastrow $J_2$ (Pade form) without exponentiation.
This returns J; callers attach exp(J) to the wavefunction.
- Parameters:
jastrow_two_body_data (Jastrow_two_body_data) – Two-body Jastrow parameter container.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Two-body Jastrow value (before exponentiation).
- Return type:
float
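A minimal pure-Python sketch of the two-body term uses the Pade form $J_2(r) = r/(2(1 + a r))$ with $a$ = jastrow_2b_param, as quoted for compute_grads_and_laplacian_Jastrow_two_body. It assumes, although this section does not state it, that the same form is summed over every electron pair regardless of spin. The last lines form a Jastrow ratio $\exp(J(\mathbf r'))/\exp(J(\mathbf r))$ as $\exp(J(\mathbf r') - J(\mathbf r))$, which is the per-grid quantity compute_ratio_Jastrow_part returns when J2 is the only active term.

```python
import math
from itertools import combinations


def pade_j2(r: float, a: float) -> float:
    # J2(r) = r / (2 (1 + a r)): slope 1/2 at r = 0, saturates to 1/(2a) at large r.
    return r / (2.0 * (1.0 + a * r))


def jastrow_two_body(r_up, r_dn, a=1.0):
    """Sum of pade_j2 over all distinct electron pairs (assumed spin-independent)."""
    electrons = list(r_up) + list(r_dn)
    return sum(pade_j2(math.dist(p, q), a) for p, q in combinations(electrons, 2))


r_up = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
r_dn = [(0.0, 2.0, 0.0)]
j_old = jastrow_two_body(r_up, r_dn)

# Move the spin-down electron and form the Jastrow ratio exp(J_new) / exp(J_old).
r_dn_new = [(0.0, 1.5, 0.0)]
ratio = math.exp(jastrow_two_body(r_up, r_dn_new) - j_old)
print(round(j_old, 6), round(ratio, 6))
```

Taking the exponential of the difference, rather than dividing two exponentials, avoids overflow when J itself is large.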
- jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_one_body(jastrow_one_body_data, r_up_carts, r_dn_carts)#
Analytic gradients and per-electron Laplacians for the one-body Jastrow.
- Parameters:
jastrow_one_body_data (Jastrow_one_body_data) – One-body Jastrow parameters and structure data.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), and Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
- jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_part(jastrow_data, r_up_carts, r_dn_carts)#
Per-electron gradients and Laplacians of the full Jastrow $J$.
Analytic paths are used for J1/J2/J3 when available; the NN three-body term (if present) is handled via autodiff. Values are returned per electron (not summed) to match downstream kinetic-energy estimators.
- Parameters:
jastrow_data (Jastrow_data) – Active Jastrow components.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), and Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
- jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_three_body(jastrow_three_body_data, r_up_carts, r_dn_carts)#
Analytic gradients and Laplacians for the three-body Jastrow.
This routine uses analytic AO/MO gradients and Laplacians. Per-electron derivatives are returned (not summed), matching kinetic-energy estimator expectations.
- Parameters:
jastrow_three_body_data (Jastrow_three_body_data) – Three-body Jastrow parameters and orbitals.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), and Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
- jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_two_body(jastrow_two_body_data, r_up_carts, r_dn_carts)#
Analytic gradients and Laplacians for the Pade two-body Jastrow.
Uses the functional form J2(r) = r / (2 * (1 + a r)) with a = jastrow_2b_param. Returns per-electron quantities (not summed).
- Parameters:
jastrow_two_body_data (Jastrow_two_body_data) – Two-body Jastrow parameter container.
r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).
r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).
- Returns:
Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), and Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).
- Return type:
tuple[jax.Array, jax.Array, jax.Array, jax.Array]
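For a single pair separated by $r$, differentiating the Pade form gives $J_2'(r) = 1/(2(1 + a r)^2)$ and, in three dimensions, $\nabla^2 J_2 = J_2''(r) + 2J_2'(r)/r = -a/(1 + a r)^3 + 1/(r(1 + a r)^2)$. A per-electron quantity is the sum of such pair contributions over that electron's partners. The pure-Python sketch below checks the closed forms against finite differences of $J_2(|\mathbf r_1 - \mathbf r_2|)$ with respect to electron 1; it verifies the calculus only and does not reproduce jqmc's batching over electrons.

```python
import math


def pade_j2(r, a):
    return r / (2.0 * (1.0 + a * r))


def grad_lap_pair(r1, r2, a):
    """Analytic gradient and Laplacian of J2(|r1 - r2|) with respect to r1."""
    d = [x - y for x, y in zip(r1, r2)]
    r = math.sqrt(sum(c * c for c in d))
    dj = 1.0 / (2.0 * (1.0 + a * r) ** 2)  # J2'(r)
    grad = [dj * c / r for c in d]  # J2'(r) * (r1 - r2) / r
    lap = -a / (1.0 + a * r) ** 3 + 1.0 / (r * (1.0 + a * r) ** 2)
    return grad, lap


a = 0.8
r1, r2 = [0.3, -0.2, 0.5], [1.1, 0.4, -0.7]
grad, lap = grad_lap_pair(r1, r2, a)

# Central finite differences with respect to r1.
h = 1e-4
f0 = pade_j2(math.dist(r1, r2), a)
fd_grad, fd_lap = [], 0.0
for k in range(3):
    rp = list(r1)
    rp[k] += h
    rm = list(r1)
    rm[k] -= h
    fp = pade_j2(math.dist(rp, r2), a)
    fm = pade_j2(math.dist(rm, r2), a)
    fd_grad.append((fp - fm) / (2 * h))
    fd_lap += (fp - 2 * f0 + fm) / h**2

print(max(abs(g - f) for g, f in zip(grad, fd_grad)), abs(lap - fd_lap))
```

Both printed discrepancies should be at the level of finite-difference error, far below the size of the derivatives themselves.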
- jqmc.jastrow_factor.compute_ratio_Jastrow_part(jastrow_data, old_r_up_carts, old_r_dn_carts, new_r_up_carts_arr, new_r_dn_carts_arr)#
Compute $\exp(J(\mathbf r'))/\exp(J(\mathbf r))$ for batched moves.
The ratio includes the exponential, and the inputs are jax.Array objects. One ratio is returned per proposed grid configuration.
- Parameters:
jastrow_data (Jastrow_data) – Active Jastrow components.
old_r_up_carts (Array) – Reference spin-up coordinates with shape (N_up, 3).
old_r_dn_carts (Array) – Reference spin-down coordinates with shape (N_dn, 3).
new_r_up_carts_arr (Array) – Proposed spin-up coordinates with shape (N_grid, N_up, 3).
new_r_dn_carts_arr (Array) – Proposed spin-down coordinates with shape (N_grid, N_dn, 3).
- Returns:
Jastrow ratios per grid with shape (N_grid,) (includes exp).
- Return type:
jax.Array
jqmc.jqmc_cli module#
Command-line module.
jqmc.jqmc_gfmc module#
QMC module.
- class jqmc.jqmc_gfmc.GFMC_n(hamiltonian_data=None, num_walkers=40, num_mcmc_per_measurement=16, num_gfmc_collect_steps=5, mcmc_seed=34467, E_scf=0.0, alat=0.1, random_discretized_mesh=True, non_local_move='tmove', comput_position_deriv=False)#
Bases:
object
GFMC class. Running GFMC with multiple walkers.
- Parameters:
hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data
num_walkers (int) – the number of walkers
mcmc_seed (int) – seed for the MCMC chain.
E_scf (float) – Self-consistent E (Hartree)
alat (float) – discretized grid length (bohr)
random_discretized_mesh (bool) – Flag for the random discretization mesh in the kinetic part and the non-local part of ECPs. Valid for both all-electron and ECP calculations.
non_local_move (str) – treatment of the non-local term: tmove (Casula’s T-move) or dtmove (Determinant Locality Approximation with Casula’s T-move). Valid only for ECP calculations. Do not specify this value for all-electron calculations.
comput_position_deriv (bool) – if True, compute the derivatives of E with respect to atomic positions.
num_mcmc_per_measurement (int)
num_gfmc_collect_steps (int)
- property alat#
Return alat.
- property bare_w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored weight array with shape (mcmc_counter, 1).
- property comput_position_deriv: bool#
Return the flag for computing the derivatives of E wrt. atomic positions.
- property de_L_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dR array with shape (mcmc_counter, 1).
- property de_L_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).
- property de_L_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).
- property dln_Psi_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dR array with shape (mcmc_counter, 1, num_atoms, 3).
- property dln_Psi_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).
- property dln_Psi_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).
- property domega_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dOmega/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).
- property domega_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dOmega/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).
- property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L array with shape (mcmc_counter, 1).
- property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L^2 array with shape (mcmc_counter, 1).
- get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#
Return the mean and std of the computed local energy.
- Parameters:
num_mcmc_warmup_steps (int) – the number of warmup steps.
num_mcmc_bin_blocks (int) – the number of binning blocks.
- Returns:
The mean and std of the total energy and of the variance, estimated by the jackknife method with the given arguments: (E_mean, E_std, Var_mean, Var_std).
- Return type:
tuple[float, float, float, float]
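Since the GFMC energy is a ratio of weighted sums, a jackknife over bins is a natural way to attach an error bar to it. A minimal sketch, assuming the estimator is E = Σ w_L e_L / Σ w_L over post-warmup samples split into equal contiguous bins; the helper name `weighted_jackknife_energy` is hypothetical and the exact weighting and binning used here may differ.

```python
import math
from typing import Sequence, Tuple


def weighted_jackknife_energy(
    w_L: Sequence[float], e_L: Sequence[float], num_warmup: int, num_bins: int
) -> Tuple[float, float]:
    """Jackknife mean and std of E = sum(w * e) / sum(w) (hypothetical helper)."""
    w, e = list(w_L[num_warmup:]), list(e_L[num_warmup:])
    bin_size = len(w) // num_bins if num_bins >= 2 else 0
    if bin_size == 0:
        raise ValueError("need num_bins >= 2 and at least one sample per bin after warmup")
    # Per-bin sums of w and w*e; trailing samples that do not fill a bin are dropped.
    sw, swe = [], []
    for b in range(num_bins):
        sl = slice(b * bin_size, (b + 1) * bin_size)
        sw.append(sum(w[sl]))
        swe.append(sum(wi * ei for wi, ei in zip(w[sl], e[sl])))
    tot_w, tot_we = sum(sw), sum(swe)
    # Leave-one-bin-out estimates of the weighted mean.
    jk = [(tot_we - swe[b]) / (tot_w - sw[b]) for b in range(num_bins)]
    mean = sum(jk) / num_bins
    var = (num_bins - 1) / num_bins * sum((x - mean) ** 2 for x in jk)
    return mean, math.sqrt(var)
```

With unit weights this reduces to the ordinary jackknife, whose error bar equals the standard error of the bin means.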
- get_aF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#
Return the mean and std of the computed atomic forces.
- Parameters:
num_mcmc_warmup_steps (int) – the number of warmup steps.
num_mcmc_bin_blocks (int) – the number of binning blocks.
- Returns:
The mean and std of the computed atomic forces, estimated by the jackknife method with the given arguments. Each array has shape (N, 3).
- Return type:
tuple[npt.NDArray, npt.NDArray]
- property hamiltonian_data#
Return hamiltonian_data.
- property mcmc_counter: int#
Return current MCMC counter.
- property num_gfmc_collect_steps#
Return num_gfmc_collect_steps.
- property num_walkers#
The number of walkers.
- property omega_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored Omega (for down electrons) array with shape (mcmc_counter, 1, num_atoms, num_electrons_dn).
- property omega_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored Omega (for up electrons) array with shape (mcmc_counter, 1, num_atoms, num_electrons_up).
- run(num_mcmc_steps=50, max_time=86400)#
Run LRDMC with multiple walkers.
- Parameters:
num_mcmc_steps (int) – number of branching steps (reconfigurations of walkers).
max_time (int) – maximum time in sec.
- Return type:
None
- property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored weight array with shape (mcmc_counter, 1).
- class jqmc.jqmc_gfmc.GFMC_t(hamiltonian_data=None, num_walkers=40, num_gfmc_collect_steps=5, mcmc_seed=34467, tau=0.1, alat=0.1, random_discretized_mesh=True, non_local_move='tmove')#
Bases:
object
GFMC class. Running GFMC.
- Parameters:
hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data.
num_walkers (int) – the number of walkers.
num_gfmc_collect_steps (int) – the number of steps to collect the GFMC data.
mcmc_seed (int) – seed for the MCMC chain.
tau (float) – projection time (Hartree^-1).
alat (float) – discretized grid length (bohr).
random_discretized_mesh (bool) – Flag for the random discretization mesh in the kinetic part and in the non-local part of ECPs. Valid for both all-electron and ECP calculations.
non_local_move (str) – treatment of the sign-flip term: tmove (Casula’s T-move) or dtmove (Determinant Locality Approximation with Casula’s T-move). Valid only for ECP calculations; do not specify this value for all-electron calculations.
- property alat#
Return alat.
- property bare_w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored weight array with shape (mcmc_counter, 1).
- property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L array with shape (mcmc_counter, 1).
- property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L2 array with shape (mcmc_counter, 1).
- get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#
Return the mean and std of the computed local energy.
- Parameters:
num_mcmc_warmup_steps (int) – the number of warmup steps.
num_mcmc_bin_blocks (int) – the number of binning blocks.
- Returns:
The mean and std of the total energy and of the variance, estimated by the jackknife method with the given arguments: (E_mean, E_std, Var_mean, Var_std).
- Return type:
tuple[float, float, float, float]
- property hamiltonian_data#
Return hamiltonian_data.
- property mcmc_counter: int#
Return current MCMC counter.
- property num_gfmc_collect_steps#
Return num_gfmc_collect_steps.
- property num_walkers#
The number of walkers.
- run(num_mcmc_steps=50, max_time=86400)#
Run LRDMC with multiple walkers.
- Parameters:
num_mcmc_steps (int) – number of branching steps (reconfigurations of walkers).
max_time (int) – maximum time in sec.
- Return type:
None
- property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored weight array with shape (mcmc_counter, 1).
jqmc.jqmc_mcmc module#
QMC module.
- class jqmc.jqmc_mcmc.MCMC(hamiltonian_data=None, mcmc_seed=34467, num_walkers=40, num_mcmc_per_measurement=16, Dt=2.0, epsilon_AS=0.1, comput_param_deriv=False, comput_position_deriv=False, random_discretized_mesh=True)#
Bases:
object
Production VMC/MCMC driver with multiple walkers.
This class drives Metropolis–Hastings sampling for many independent walkers in parallel (vectorized with jax.vmap) and stores all observables needed by downstream analysis and optimization. All public methods are part of the supported API; private helpers are internal and subject to change.
- Parameters:
hamiltonian_data (Hamiltonian_data)
mcmc_seed (int)
num_walkers (int)
num_mcmc_per_measurement (int)
Dt (float)
epsilon_AS (float)
comput_param_deriv (bool)
comput_position_deriv (bool)
random_discretized_mesh (bool)
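The per-walker sampling loop can be pictured with a toy model. The sketch below runs independent Metropolis walkers on the 1D harmonic-oscillator ground state ψ(x) = exp(−x²/2) with symmetric uniform proposals of width `step`, and stores local energies with shape (num_steps, num_walkers) like the e_L buffer. It is an illustrative assumption only: jqmc's own proposal (and how it uses Dt and epsilon_AS) is not reproduced, and a plain Python loop stands in for jax.vmap.

```python
import math
import random
from typing import List, Tuple


def log_psi(x: float) -> float:
    """Toy trial function ln psi(x) = -x^2 / 2 (hypothetical)."""
    return -0.5 * x * x


def local_energy(x: float) -> float:
    """E_L = -(1/2) psi''/psi + x^2/2; psi''/psi = x^2 - 1, so E_L = 1/2 for every x."""
    return -0.5 * (x * x - 1.0) + 0.5 * x * x


def run_walkers(num_walkers: int, num_steps: int, step: float, seed: int) -> Tuple[List[List[float]], float]:
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(num_walkers)]
    e_L: List[List[float]] = []  # shape (num_steps, num_walkers)
    accepted = 0
    for _ in range(num_steps):
        for w in range(num_walkers):
            x_new = xs[w] + step * rng.uniform(-1.0, 1.0)
            # Accept with probability min(1, |psi(x_new)|^2 / |psi(x)|^2).
            if rng.random() < math.exp(2.0 * (log_psi(x_new) - log_psi(xs[w]))):
                xs[w] = x_new
                accepted += 1
        e_L.append([local_energy(x) for x in xs])
    return e_L, accepted / (num_steps * num_walkers)


e_L, acceptance = run_walkers(num_walkers=4, num_steps=50, step=1.0, seed=34467)
```

Because the toy ψ is an exact eigenstate, every stored local energy equals 1/2 (the zero-variance property), which makes the buffer layout easy to check.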
- property comput_param_deriv: bool#
Return the flag for computing the derivatives of E wrt. variational parameters.
- property comput_position_deriv: bool#
Return the flag for computing the derivatives of E wrt. atomic positions.
- property de_L_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dR array with shape (mcmc_counter, num_walkers).
- property de_L_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).
- property de_L_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored de_L/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).
- property dln_Psi_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dR array with shape (mcmc_counter, num_walkers, num_atoms, 3).
- property dln_Psi_dc: dict[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]]#
Return stored parameter gradients keyed by block name.
- property dln_Psi_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).
- property dln_Psi_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dln_Psi/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).
- property domega_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dOmega/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).
- property domega_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored dOmega/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).
- property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L array with shape (mcmc_counter, num_walkers).
- property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored e_L^2 array with shape (mcmc_counter, num_walkers).
- get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#
Estimate total energy and variance with jackknife error bars.
- Parameters:
num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.
num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.
- Returns:
(E_mean, E_std, Var_mean, Var_std) aggregated across MPI ranks.
- Return type:
tuple[float, float, float, float]
- Raises:
ValueError – If there are insufficient post-warmup samples to form the requested blocks.
Notes
Warns when warmup or block counts fall below
MCMC_MIN_WARMUP_STEPS/MCMC_MIN_BIN_BLOCKS. All reductions are MPI-aware.
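The warmup/block bookkeeping, the insufficient-sample error, and the variance estimate can be sketched as follows. Assumptions: blocks are contiguous, equal-size runs of post-warmup steps pooled over walkers; the variance is ⟨e_L²⟩ − ⟨e_L⟩² evaluated on each leave-one-block-out sample; MPI reduction is omitted; `jackknife_energy_variance` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def jackknife_energy_variance(
    e_L: Sequence[Sequence[float]], num_warmup: int, num_bins: int
) -> Tuple[float, float, float, float]:
    """Return (E_mean, E_std, Var_mean, Var_std) from e_L[step][walker] (hypothetical helper)."""
    kept = e_L[num_warmup:]
    if num_bins < 2 or len(kept) < num_bins:
        raise ValueError(f"{len(kept)} post-warmup steps cannot form {num_bins} jackknife blocks")
    bin_size = len(kept) // num_bins
    # Per-block sums of e_L and e_L^2 over steps and walkers, plus sample counts.
    s1, s2, n = [], [], []
    for b in range(num_bins):
        block = [x for row in kept[b * bin_size:(b + 1) * bin_size] for x in row]
        s1.append(sum(block))
        s2.append(sum(x * x for x in block))
        n.append(len(block))
    S1, S2, N = sum(s1), sum(s2), sum(n)
    e_jk, v_jk = [], []
    for b in range(num_bins):
        m1 = (S1 - s1[b]) / (N - n[b])  # <e_L> without block b
        m2 = (S2 - s2[b]) / (N - n[b])  # <e_L^2> without block b
        e_jk.append(m1)
        v_jk.append(m2 - m1 * m1)

    def mean_std(xs: List[float]) -> Tuple[float, float]:
        mu = sum(xs) / num_bins
        return mu, math.sqrt((num_bins - 1) / num_bins * sum((x - mu) ** 2 for x in xs))

    e_mean, e_std = mean_std(e_jk)
    v_mean, v_std = mean_std(v_jk)
    return e_mean, e_std, v_mean, v_std
```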
- get_aF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#
Compute Hellmann–Feynman + Pulay forces with jackknife statistics.
- Parameters:
num_mcmc_warmup_steps (int, optional) – Samples to drop for warmup. Defaults to 50.
num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.
- Returns:
(force_mean, force_std), each with shape (num_atoms, 3), in Hartree/bohr.
- Return type:
tuple[npt.NDArray, npt.NDArray]
Notes
Uses stored per-walker weights, energies, wavefunction gradients, and SWCT terms accumulated during
run(); reductions are MPI-aware.
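For orientation, the textbook VMC force estimator combining the Hellmann–Feynman and Pulay contributions reads as follows for a real wavefunction, with averages taken over the sampled \(|\Psi|^2\) distribution. This is a simplified form: the SWCT terms mentioned above are omitted, so it is not the exact expression implemented here.

\[
F_\alpha = -\frac{\partial E}{\partial R_\alpha}
= -\left\langle \frac{\partial E_L}{\partial R_\alpha} \right\rangle
- 2\left\langle \left(E_L - \langle E_L \rangle\right) \frac{\partial \ln \Psi}{\partial R_\alpha} \right\rangle
\]

The quantities \(\partial E_L/\partial R_\alpha\) and \(\partial \ln\Psi/\partial R_\alpha\) correspond to what the de_L_dR and dln_Psi_dR buffers appear to hold, and \(\langle E_L \rangle\) to the mean of e_L.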
- get_dln_WF(blocks, num_mcmc_warmup_steps=50, chosen_param_index=None)#
Assemble per-sample derivatives of ln Psi w.r.t. variational parameters.
- Parameters:
blocks (list[VariationalParameterBlock]) – Ordered variational blocks used for concatenation.
num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.
chosen_param_index (list | None, optional) – Optional subset of flattened indices; None keeps all. Defaults to None.
- Returns:
O_matrix with shape (M, num_walkers, K) after warmup, where K follows the provided blocks (or subset).
- Return type:
npt.NDArray
Notes
Validates the concatenated gradient size against block metadata and uses gradients stored during
run().
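The assembly amounts to flattening each block's per-sample gradient, concatenating the blocks in order, checking the total length against the block sizes, and optionally keeping a subset of columns. A minimal sketch with nested lists, assuming a hypothetical storage layout `grads[name][step][walker]` (a flat list per sample) and (name, size) pairs in place of VariationalParameterBlock:

```python
from typing import Dict, List, Optional, Sequence, Tuple

Block = Tuple[str, int]  # hypothetical (name, size) pair mirroring VariationalParameterBlock
Grads = Dict[str, List[List[List[float]]]]  # hypothetical layout: grads[name][step][walker] -> flat list


def assemble_o_matrix(
    grads: Grads,
    blocks: Sequence[Block],
    num_warmup: int,
    chosen_param_index: Optional[Sequence[int]] = None,
) -> List[List[List[float]]]:
    """Return O[m][w][k] for post-warmup steps m, concatenating blocks in the given order."""
    expected = sum(size for _, size in blocks)
    first = grads[blocks[0][0]]
    o_matrix = []
    for m in range(num_warmup, len(first)):
        row = []
        for w in range(len(first[m])):
            flat = [g for name, _ in blocks for g in grads[name][m][w]]
            if len(flat) != expected:  # mirror the documented size validation
                raise ValueError(f"gradient length {len(flat)} != block total {expected}")
            if chosen_param_index is not None:
                flat = [flat[k] for k in chosen_param_index]
            row.append(flat)
        o_matrix.append(row)
    return o_matrix
```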
- get_gF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10, chosen_param_index=None, blocks=None)#
Evaluate generalized forces (dE/dc_k) with jackknife error bars.
- Parameters:
num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.
num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.
chosen_param_index (list | None, optional) – Optional subset of flattened indices. Defaults to None.
blocks (list | None, optional) – Variational blocks for parameter ordering; defaults to the current wavefunction blocks.
- Returns:
(generalized_force_mean, generalized_force_std) as 1D vectors of length L after any filtering.
- Return type:
tuple[npt.NDArray, npt.NDArray]
Notes
Reuses get_dln_WF() after warmup and applies jackknife statistics across MPI ranks.
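For a real wavefunction, the energy derivative behind a generalized force is \(\partial E/\partial c_k = 2\langle (E_L - \langle E_L\rangle)(O_k - \langle O_k\rangle)\rangle\) with \(O_k = \partial \ln\Psi/\partial c_k\), i.e. the columns of the O_matrix. A plain-Python sketch of that covariance on flattened samples, without jackknife binning or MPI reduction; the sign convention of the reported generalized force is not settled by this sketch.

```python
from typing import List, Sequence


def energy_gradient(e_L: Sequence[float], O: Sequence[Sequence[float]]) -> List[float]:
    """Estimate dE/dc_k = 2 <(E_L - <E_L>)(O_k - <O_k>)> from N samples (rows of O)."""
    n = len(e_L)
    e_mean = sum(e_L) / n
    grad = []
    for k in range(len(O[0])):
        o_mean = sum(row[k] for row in O) / n
        cov = sum((e - e_mean) * (row[k] - o_mean) for e, row in zip(e_L, O)) / n
        grad.append(2.0 * cov)
    return grad


grad = energy_gradient([1.0, 3.0], [[0.0, 5.0], [1.0, 5.0]])
```

A parameter whose O_k is constant over all samples (the second column here) has zero gradient, as expected for a parameter that only rescales Ψ.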
- property hamiltonian_data#
Access the mutable Hamiltonian_data backing this sampler.
- property mcmc_counter: int#
Number of Metropolis steps accumulated (rows in stored observables).
- property num_walkers#
The number of walkers.
- property omega_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored Omega (for down electrons) array with shape (mcmc_counter, num_walkers, num_atoms, num_electrons_dn).
- property omega_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored Omega (for up electrons) array with shape (mcmc_counter, num_walkers, num_atoms, num_electrons_up).
- run(num_mcmc_steps=0, max_time=86400)#
Execute Metropolis–Hastings sampling for all walkers.
- Parameters:
num_mcmc_steps (int, optional) – Metropolis updates per walker; values <= 0 are no-ops. Defaults to 0.
max_time (int, optional) – Wall-clock budget in seconds. Defaults to 86400.
- Return type:
None
Notes
Creates external_control_mcmc.toml to allow external stop requests.
Accumulates energies, weights, forces, and wavefunction gradients into public buffers (w_L, e_L, dln_Psi_*, etc.).
Logs timing statistics and acceptance ratios at the end of the run.
- run_optimize(num_mcmc_steps=100, num_opt_steps=1, wf_dump_freq=10, max_time=86400, num_mcmc_warmup_steps=0, num_mcmc_bin_blocks=100, opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=False, num_param_opt=0, optimizer_kwargs=None)#
Optimize wavefunction parameters using SR or optax.
- Parameters:
num_mcmc_steps (int, optional) – MCMC samples per walker per iteration. Defaults to 100.
num_opt_steps (int, optional) – Number of optimization iterations. Defaults to 1.
wf_dump_freq (int, optional) – Dump frequency for hamiltonian_data.chk. Defaults to 10.
max_time (int, optional) – Per-iteration MCMC wall-clock budget (sec). Defaults to 86400.
num_mcmc_warmup_steps (int, optional) – Warmup samples discarded each iteration. Defaults to 0.
num_mcmc_bin_blocks (int, optional) – Jackknife bins for statistics. Defaults to 100.
opt_J1_param (bool, optional) – Optimize the one-body Jastrow. Defaults to True.
opt_J2_param (bool, optional) – Optimize the two-body Jastrow. Defaults to True.
opt_J3_param (bool, optional) – Optimize the three-body Jastrow. Defaults to True.
opt_JNN_param (bool, optional) – Optimize the NN Jastrow. Defaults to True.
opt_lambda_param (bool, optional) – Optimize the determinant lambda matrix. Defaults to False.
num_param_opt (int, optional) – Limit the parameters updated (ranked by |f|/|std f|); 0 means all. Defaults to 0.
optimizer_kwargs (dict | None, optional) – Optimizer configuration. method='sr' uses the SR keys (delta, epsilon, cg_flag, cg_max_iter, cg_tol); other method names are optax constructors (e.g., "adam") and receive the remaining keys.
Notes
Persists optax optimizer state across calls when method and hyperparameters match.
Writes external_control_opt.toml to allow external stop requests.
Updates Hamiltonian_data in place and increments the optimization counter.
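The num_param_opt filter ranks parameters by the signal-to-noise ratio |f|/|std f| and updates only the most significant ones. A sketch of that selection; `select_params` is a hypothetical name, and breaking ties by index and treating a zero error bar as maximally significant are both assumptions.

```python
import math
from typing import List, Sequence


def select_params(f_mean: Sequence[float], f_std: Sequence[float], num_param_opt: int) -> List[int]:
    """Return sorted indices of the num_param_opt parameters with the largest |f|/|std f|.

    num_param_opt == 0 keeps every parameter, mirroring the documented default.
    """
    n = len(f_mean)
    if num_param_opt <= 0 or num_param_opt >= n:
        return list(range(n))

    def snr(k: int) -> float:
        # Assumption: a zero error bar makes the parameter maximally significant.
        return math.inf if f_std[k] == 0 else abs(f_mean[k]) / abs(f_std[k])

    ranked = sorted(range(n), key=lambda k: (-snr(k), k))
    return sorted(ranked[:num_param_opt])
```

Returning sorted indices keeps the selected parameters in their original block order, which is convenient when the result is used as a chosen_param_index.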
- property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#
Return the stored weight array with shape (mcmc_counter, num_walkers).
jqmc.jqmc_miscs module#
jQMC miscs.
jqmc.jqmc_tool module#
jQMC tools.
- jqmc.jqmc_tool.hamiltonian_convert_wavefunction(hamiltonian_data_org_file=<typer.models.ArgumentInfo object>, convert_to=<typer.models.OptionInfo object>, hamiltonian_data_conv_file=<typer.models.OptionInfo object>)#
Convert wavefunction data in the Hamiltonian data.
- Parameters:
hamiltonian_data_org_file (str)
convert_to (ansatz_type)
hamiltonian_data_conv_file (str)
- jqmc.jqmc_tool.hamiltonian_show_info(hamiltonian_data=<typer.models.ArgumentInfo object>)#
Show information stored in the Hamiltonian data.
- Parameters:
hamiltonian_data (str)
- jqmc.jqmc_tool.hamiltonian_to_xyz(hamiltonian_data=<typer.models.ArgumentInfo object>, xyz_file=<typer.models.OptionInfo object>)#
Export the structure stored in the Hamiltonian data to an XYZ file.
- Parameters:
hamiltonian_data (str)
xyz_file (str)
- jqmc.jqmc_tool.lrdmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#
LRDMC chk file fix.
- Parameters:
restart_chk (str)
- jqmc.jqmc_tool.lrdmc_compute_energy(restart_chk=<typer.models.ArgumentInfo object>, num_gfmc_bin_block=<typer.models.OptionInfo object>, num_gfmc_warmup_steps=<typer.models.OptionInfo object>, num_gfmc_collect_steps=<typer.models.OptionInfo object>)#
LRDMC energy calculation.
- Parameters:
restart_chk (str)
num_gfmc_bin_block (int)
num_gfmc_warmup_steps (int)
num_gfmc_collect_steps (int)
- jqmc.jqmc_tool.lrdmc_extrapolate_energy(restart_chks=<typer.models.ArgumentInfo object>, polynomial_order=<typer.models.OptionInfo object>, plot_graph=<typer.models.OptionInfo object>, save_graph=<typer.models.OptionInfo object>, num_gfmc_bin_block=<typer.models.OptionInfo object>, num_gfmc_warmup_steps=<typer.models.OptionInfo object>, num_gfmc_collect_steps=<typer.models.OptionInfo object>)#
LRDMC energy extrapolation.
- Parameters:
restart_chks (List[str])
polynomial_order (int)
plot_graph (bool)
save_graph (str)
num_gfmc_bin_block (int)
num_gfmc_warmup_steps (int)
num_gfmc_collect_steps (int)
- jqmc.jqmc_tool.lrdmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#
Generate an input file for LRDMC calculations.
- Parameters:
flag (bool)
filename (str)
exclude_comment (bool)
- jqmc.jqmc_tool.mcmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#
VMC chk file fix.
- Parameters:
restart_chk (str)
- jqmc.jqmc_tool.mcmc_compute_energy(restart_chk=<typer.models.ArgumentInfo object>, num_mcmc_bin_blocks=<typer.models.OptionInfo object>, num_mcmc_warmup_steps=<typer.models.OptionInfo object>)#
VMC energy calculation.
- Parameters:
restart_chk (str)
num_mcmc_bin_blocks (int)
num_mcmc_warmup_steps (int)
- jqmc.jqmc_tool.mcmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#
Generate an input file for VMC calculations.
- Parameters:
flag (bool)
filename (str)
exclude_comment (bool)
- class jqmc.jqmc_tool.orbital_type(value)#
Bases:
str, Enum
Orbital type.
- ao = 'ao'#
- ao_full = 'ao-full'#
- ao_large = 'ao-large'#
- ao_medium = 'ao-medium'#
- ao_small = 'ao-small'#
- mo = 'mo'#
- none = 'none'#
- jqmc.jqmc_tool.trexio_convert_to(trexio_file=<typer.models.ArgumentInfo object>, hamiltonian_file=<typer.models.OptionInfo object>, j1_parmeter=<typer.models.OptionInfo object>, j2_parmeter=<typer.models.OptionInfo object>, j3_basis_type=<typer.models.OptionInfo object>, j_nn_type=<typer.models.OptionInfo object>, j_nn_params=<typer.models.OptionInfo object>)#
Convert a TREXIO file to hamiltonian_data.
- Parameters:
trexio_file (str)
hamiltonian_file (str)
j1_parmeter (float)
j2_parmeter (float)
j3_basis_type (orbital_type)
j_nn_type (str)
j_nn_params (List[str])
- jqmc.jqmc_tool.trexio_show_detail(filename=<typer.models.ArgumentInfo object>)#
Show detailed information stored in the TREXIO file.
- Parameters:
filename (str)
- jqmc.jqmc_tool.trexio_show_info(filename=<typer.models.ArgumentInfo object>)#
Show information stored in the TREXIO file.
- Parameters:
filename (str)
- jqmc.jqmc_tool.vmc_analyze_output(filenames=<typer.models.ArgumentInfo object>, plot_graph=<typer.models.OptionInfo object>, save_graph=<typer.models.OptionInfo object>)#
Analyze the output files of VMC optimizations.
- Parameters:
filenames (List[str])
plot_graph (bool)
save_graph (str)
- jqmc.jqmc_tool.vmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#
VMCopt chk file fix.
- Parameters:
restart_chk (str)
- jqmc.jqmc_tool.vmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#
Generate an input file for VMCopt calculations.
- Parameters:
flag (bool)
filename (str)
exclude_comment (bool)
jqmc.jqmc_utility module#
Utility module.
jqmc.molecular_orbital module#
Molecular Orbital module.
- class jqmc.molecular_orbital.MOs_data(num_mo=0, aos_data=<factory>, mo_coefficients=<factory>)#
Bases:
object
Molecular orbital (MO) coefficients and metadata.
Holds the contraction matrix that maps atomic orbitals (AOs) to molecular orbitals (MOs). MO values are obtained as mo_coefficients @ AO_values in float64 (jax_enable_x64=True).
- Variables:
num_mo (int) – Number of molecular orbitals.
aos_data (AOs_sphe_data | AOs_cart_data) – AO definition supplying centers, exponents/coefficients, angular data, and contraction mapping.
mo_coefficients (npt.NDArray | jax.Array) – Coefficient matrix of shape (num_mo, num_ao). Rows correspond to MOs; columns correspond to contracted AOs.
- Parameters:
num_mo (int)
aos_data (AOs_sphe_data | AOs_cart_data)
mo_coefficients (ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex)
Examples
Minimal runnable setup (2 AOs -> 1 MO):
    import numpy as np
    from jqmc.structure import Structure_data
    from jqmc.atomic_orbital import AOs_sphe_data
    from jqmc.molecular_orbital import MOs_data

    # Two hydrogen atoms on the z axis (Bohr).
    structure = Structure_data(
        positions=np.array([[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]),
        pbc_flag=False,
        atomic_numbers=[1, 1],
        element_symbols=["H", "H"],
        atomic_labels=["H1", "H2"],
    )
    # One s-type primitive per atom.
    aos = AOs_sphe_data(
        structure_data=structure,
        nucleus_index=[0, 1],
        num_ao=2,
        num_ao_prim=2,
        orbital_indices=[0, 1],
        exponents=[1.0, 1.2],
        coefficients=[1.0, 0.8],
        angular_momentums=[0, 0],
        magnetic_quantum_numbers=[0, 0],
    )
    aos.sanity_check()

    mo_coeffs = np.array([[0.7, 0.7]], dtype=float)  # shape (1, 2)
    mos = MOs_data(num_mo=1, aos_data=aos, mo_coefficients=mo_coeffs)
    mos.sanity_check()
- aos_data: AOs_sphe_data | AOs_cart_data#
AO definition supplying centers, exponents/coefficients, angular data, and contraction mapping.
- mo_coefficients: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex#
MO coefficient matrix, shape
(num_mo, num_ao).
- num_mo: int = 0#
Number of molecular orbitals.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Validate internal consistency.
Ensures mo_coefficients matches (num_mo, aos_data.num_ao), verifies num_mo is an int, and delegates AO validation to aos_data.sanity_check().
- Raises:
ValueError – If the coefficient shape or num_mo type is invalid, or if aos_data fails its check.
- Return type:
None
- property structure_data#
Return structure_data of the aos_data instance.
- jqmc.molecular_orbital.compute_MOs(mos_data, r_carts)#
Evaluate molecular orbitals at electron coordinates.
- Parameters:
mos_data (MOs_data) – MO/AO definition and coefficient matrix.
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64 (same convention as AO evaluators).
- Returns:
MO values with shape (num_mo, N_e).
- Return type:
jax.Array
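Since MO values are linear combinations of AO values, evaluation reduces to a matrix product of the (num_mo, num_ao) coefficients with a (num_ao, N_e) table of AO values. A minimal sketch mirroring the H2 example above, with hypothetical single-primitive, unnormalized s Gaussians c·exp(−ζ|r − R|²) standing in for jqmc's AO evaluators:

```python
import math
from typing import List, Sequence

Vec3 = Sequence[float]


def s_gaussian_aos(
    centers: Sequence[Vec3], exponents: Sequence[float],
    coefficients: Sequence[float], r_carts: Sequence[Vec3],
) -> List[List[float]]:
    """AO table of shape (num_ao, N_e) for single-primitive s Gaussians (toy model)."""
    return [
        [c * math.exp(-z * math.dist(r, R) ** 2) for r in r_carts]
        for R, z, c in zip(centers, exponents, coefficients)
    ]


def compute_mos_sketch(
    mo_coefficients: Sequence[Sequence[float]], ao_values: Sequence[Sequence[float]]
) -> List[List[float]]:
    """(num_mo, num_ao) @ (num_ao, N_e) -> (num_mo, N_e)."""
    n_e = len(ao_values[0])
    return [
        [sum(c_ij * ao_values[j][e] for j, c_ij in enumerate(row)) for e in range(n_e)]
        for row in mo_coefficients
    ]


centers = [[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]
aos = s_gaussian_aos(centers, [1.0, 1.2], [1.0, 0.8], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.70]])
mos = compute_mos_sketch([[0.7, 0.7]], aos)
```

By the same linearity, MO gradients and Laplacians follow from applying the coefficient matrix to AO gradients and Laplacians.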
- jqmc.molecular_orbital.compute_MOs_grad(mos_data, r_carts)#
Compute MO gradients (x, y, z components) at electron coordinates.
- Parameters:
mos_data (MOs_data) – MO/AO definition and coefficient matrix.
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64.
- Returns:
Gradients per component (grad_x, grad_y, grad_z), each of shape (num_mo, N_e).
- Return type:
tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]
- jqmc.molecular_orbital.compute_MOs_laplacian(mos_data, r_carts)#
Compute MO Laplacians at electron coordinates.
- Parameters:
mos_data (MOs_data) – MO/AO definition and coefficient matrix.
r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64.
- Returns:
Laplacians of each MO, shape (num_mo, N_e).
- Return type:
jax.Array
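For a Gaussian primitive g(r) = exp(−ζ|r − R|²) the Laplacian has the closed form ∇²g = (4ζ²|r − R|² − 6ζ) g. A central finite-difference comparison is a quick way to sanity-check analytic Laplacians of this kind; the sketch below does that for one s Gaussian centered at the origin (a standalone toy check, not jqmc's implementation).

```python
import math
from typing import Sequence


def gaussian(r: Sequence[float], zeta: float) -> float:
    return math.exp(-zeta * sum(x * x for x in r))


def gaussian_laplacian(r: Sequence[float], zeta: float) -> float:
    """Analytic Laplacian (4 zeta^2 |r|^2 - 6 zeta) g(r) of a Gaussian centered at the origin."""
    r2 = sum(x * x for x in r)
    return (4.0 * zeta ** 2 * r2 - 6.0 * zeta) * gaussian(r, zeta)


def fd_laplacian(r: Sequence[float], zeta: float, h: float = 1e-3) -> float:
    """Central differences: sum_d [g(r + h e_d) - 2 g(r) + g(r - h e_d)] / h^2."""
    g0 = gaussian(r, zeta)
    total = 0.0
    for d in range(3):
        rp = list(r)
        rp[d] += h
        rm = list(r)
        rm[d] -= h
        total += (gaussian(rp, zeta) - 2.0 * g0 + gaussian(rm, zeta)) / h ** 2
    return total
```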
jqmc.setting module#
setting.
jqmc.structure module#
Structure module.
- class jqmc.structure.Structure_data(positions=<factory>, pbc_flag=False, vec_a=<factory>, vec_b=<factory>, vec_c=<factory>, atomic_numbers=<factory>, element_symbols=<factory>, atomic_labels=<factory>)#
Bases:
object
Atomic structure and cell metadata.
Stores Cartesian coordinates (Bohr), optional lattice vectors for PBC, and basic element metadata used by AOs/MOs, Coulomb, and Hamiltonian builders.
- Variables:
positions (npt.NDArray | jax.Array) – Atomic Cartesian coordinates with shape (N, 3) in Bohr.
pbc_flag (bool) – Whether periodic boundary conditions are active. If True, lattice vectors vec_a|b|c must be provided; otherwise they must be empty.
vec_a (list[float] | tuple[float]) – Lattice vector a (Bohr) when pbc_flag=True.
vec_b (list[float] | tuple[float]) – Lattice vector b (Bohr) when pbc_flag=True.
vec_c (list[float] | tuple[float]) – Lattice vector c (Bohr) when pbc_flag=True.
atomic_numbers (list[int] | tuple[int]) – Atomic numbers Z for each site (len N).
element_symbols (list[str] | tuple[str]) – Element symbols for each site (len N).
atomic_labels (list[str] | tuple[str]) – Human-readable labels for each site (len N).
- Parameters:
positions (ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex)
pbc_flag (bool)
vec_a (list[float] | tuple[float])
vec_b (list[float] | tuple[float])
vec_c (list[float] | tuple[float])
atomic_numbers (list[int] | tuple[int])
element_symbols (list[str] | tuple[str])
atomic_labels (list[str] | tuple[str])
Examples
Minimal H2 setup (Bohr):
    import numpy as np
    from jqmc.structure import Structure_data

    structure = Structure_data(
        positions=np.array([[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]),
        pbc_flag=False,
        atomic_numbers=[1, 1],
        element_symbols=["H", "H"],
        atomic_labels=["H1", "H2"],
    )
    structure.sanity_check()
- atomic_labels: list[str] | tuple[str]#
Human-readable labels per site (len
N).
- atomic_numbers: list[int] | tuple[int]#
Atomic numbers Z per site (len N).
- property cell: ndarray[tuple[Any, ...], dtype[float64]]#
Lattice vectors as a (3, 3) matrix in Bohr ([a, b, c]).
- element_symbols: list[str] | tuple[str]#
Element symbols per site (len
N).
- property lattice_vec_a: tuple#
Return lattice vector A (in Bohr).
- Returns:
the lattice vector A (in Bohr).
- Return type:
tuple[np.float64]
- property lattice_vec_b: tuple#
Return lattice vector B (in Bohr).
- Returns:
the lattice vector B (in Bohr).
- Return type:
tuple[np.float64]
- property lattice_vec_c: tuple#
Return lattice vector C (in Bohr).
- Returns:
the lattice vector C (in Bohr).
- Return type:
tuple[np.float64]
- property natom: int#
The number of atoms in the system.
- Returns:
The number of atoms in the system.
- Return type:
int
- property norm_vec_a: float#
Return the norm of the lattice vector A (in Bohr).
- Returns:
the norm of the lattice vector A (in Bohr).
- Return type:
np.float64
- property norm_vec_b: float#
Return the norm of the lattice vector B (in Bohr).
- Returns:
the norm of the lattice vector B (in Bohr).
- Return type:
np.float64
- property norm_vec_c: float#
Return the norm of the lattice vector C (in Bohr).
- Returns:
the norm of the lattice vector C (in Bohr).
- Return type:
np.float64
- property ntyp: int#
The number of element types in the system.
- Returns:
The number of element types in the system.
- Return type:
int
- pbc_flag: bool = False#
Whether periodic boundary conditions are active.
- positions: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex#
Atomic Cartesian coordinates with shape
(N, 3)in Bohr.
- property recip_cell: ndarray[tuple[Any, ...], dtype[float64]]#
Reciprocal lattice vectors as a (3, 3) matrix in Bohr^{-1}.
Uses the standard definition
\[G_a = 2\pi \frac{T_b \times T_c}{T_a \cdot (T_b \times T_c)}, \quad G_b = 2\pi \frac{T_c \times T_a}{T_b \cdot (T_c \times T_a)}, \quad G_c = 2\pi \frac{T_a \times T_b}{T_c \cdot (T_a \times T_b)},\]
and asserts the condition \(T_i \cdot G_j = 2\pi\,\delta_{ij}\).
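The definition of the reciprocal vectors translates directly into code. A standalone sketch using plain lists rather than the class's arrays, building G_a, G_b, G_c from T_a, T_b, T_c by cycling the three vectors:

```python
import math
from typing import List, Sequence, Tuple

Vec = Sequence[float]


def cross(u: Vec, v: Vec) -> List[float]:
    return [u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0]]


def dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))


def reciprocal_vectors(t_a: Vec, t_b: Vec, t_c: Vec) -> Tuple[List[float], List[float], List[float]]:
    """Return (G_a, G_b, G_c) in Bohr^-1 for lattice vectors given in Bohr."""
    g = []
    # Cyclic permutations (a, b, c), (b, c, a), (c, a, b) of the defining formula.
    for t1, t2, t3 in ((t_a, t_b, t_c), (t_b, t_c, t_a), (t_c, t_a, t_b)):
        c = cross(t2, t3)
        scale = 2.0 * math.pi / dot(t1, c)
        g.append([scale * x for x in c])
    return g[0], g[1], g[2]


T = ([2.0, 0.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 4.0])  # a non-orthogonal cell
G = reciprocal_vectors(*T)
```

For a cubic cell of side L each G reduces to 2π/L along the corresponding axis.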
- property recip_vec_a: tuple#
Return reciprocal lattice vector A (in Bohr^{-1}).
- Returns:
the reciprocal lattice vector A (in Bohr^{-1}).
- Return type:
tuple[np.float64]
- property recip_vec_b: tuple#
Return reciprocal lattice vector B (in Bohr^{-1}).
- Returns:
the reciprocal lattice vector B (in Bohr^{-1}).
- Return type:
tuple[np.float64]
- property recip_vec_c: tuple#
Return reciprocal lattice vector C (in Bohr^{-1}).
- Returns:
the reciprocal lattice vector C (in Bohr^{-1}).
- Return type:
tuple[np.float64]
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Validate consistency of positions, labels, and lattice metadata.
Ensures all per-atom arrays share length N, that lattice vectors are either all present (length 3 each) when pbc_flag=True or all empty when pbc_flag=False, and that basic types are as expected.
- Raises:
ValueError – If list lengths mismatch, lattice vectors are inconsistent with pbc_flag, or field types are incorrect.
- Return type:
None
- vec_a: list[float] | tuple[float]#
Lattice vector a in Bohr (requires
pbc_flag=True).
- vec_b: list[float] | tuple[float]#
Lattice vector b in Bohr (requires
pbc_flag=True).
- vec_c: list[float] | tuple[float]#
Lattice vector c in Bohr (requires
pbc_flag=True).
jqmc.swct module#
SWCT module.
- class jqmc.swct.SWCT_data(structure)#
Bases:
object
Space-warp coordinate transformation metadata.
- Parameters:
structure (
Structure_data) – Nuclear geometry used to build SWCT weights.
- replace(**updates)#
Returns a new object replacing the specified fields with new values.
- sanity_check()#
Check attributes of the class.
This function checks the consistency among the arguments.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- structure: Structure_data#
- jqmc.swct.evaluate_swct_domega(swct_data, r_carts)#
Evaluate \(\sum_i \nabla_{r_i} \omega_{\alpha i}\) for each atom.
- Parameters:
swct_data (SWCT_data) – Structure and cached geometry information.
r_carts (jax.Array) – Electron Cartesian coordinates with shape (N_e, 3) and float64 dtype.
- Returns:
Sum of gradients per atom with shape (N_a, 3).
- Return type:
jax.Array
- jqmc.swct.evaluate_swct_omega(swct_data, r_carts)#
Compute SWCT weights \(\omega_{\alpha i}\) for each atom/electron pair.
- Parameters:
swct_data (SWCT_data) – Structure and cached geometry information.
r_carts (jax.Array) – Electron Cartesian coordinates with shape (N_e, 3) and float64 dtype.
- Returns:
Normalized weights with shape (N_a, N_e), summing to 1 over atoms for each electron.
- Return type:
jax.Array
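A commonly used form for space-warp weights is ω_αi = F(|r_i − R_α|) / Σ_β F(|r_i − R_β|) with F(r) = r⁻⁴; whether this module uses exactly that kernel is an assumption here. The sketch returns an (N_a, N_e) array and illustrates the normalization over atoms described above; `swct_omega` is a hypothetical name.

```python
import math
from typing import List, Sequence

Vec3 = Sequence[float]


def swct_omega(R_carts: Sequence[Vec3], r_carts: Sequence[Vec3], power: float = 4.0) -> List[List[float]]:
    """omega[alpha][i] = F(|r_i - R_alpha|) / sum_beta F(|r_i - R_beta|) with F(r) = r**-power.

    An electron sitting exactly on a nucleus raises ZeroDivisionError (0.0 ** -power).
    """
    omega = [[0.0] * len(r_carts) for _ in R_carts]
    for i, r in enumerate(r_carts):
        f = [math.dist(r, R) ** -power for R in R_carts]
        norm = sum(f)
        for alpha, f_alpha in enumerate(f):
            omega[alpha][i] = f_alpha / norm
    return omega


R = [[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]
r = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.60], [0.3, 0.1, -0.2]]
omega = swct_omega(R, r)
```

An electron midway between the two nuclei is shared equally, while one close to a nucleus is assigned almost entirely to it, which is what lets each electron move rigidly with its nearest atom.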
jqmc.trexio_wrapper module#
TREXIO wrapper modules.
- jqmc.trexio_wrapper.read_trexio_file(trexio_file, store_tuple=False)#
Load a TREXIO HDF5 file into jqmc data containers.
- Parameters:
trexio_file (str) – Path to the TREXIO file to read (HDF5 backend expected).
store_tuple (bool, optional) – Store list-like fields as tuples for immutability (useful in tests with JAX/Flax), at the cost of slower production runs. Defaults to False.
- Returns:
(structure_data, aos_data, mos_data_up, mos_data_dn, geminal_data, coulomb_potential_data), where:
structure_data is a Structure_data describing atoms and geometry.
aos_data is either AOs_cart_data or AOs_sphe_data depending on the basis.
mos_data_up and mos_data_dn are MOs_data for spin-up/down orbitals.
geminal_data is a Geminal_data assembled from the MO block.
coulomb_potential_data is a Coulomb_potential_data describing (E)CPs.
- Return type:
tuple
- Raises:
NotImplementedError – If periodic cells (PBC) or complex molecular orbitals are encountered.
ValueError – If atomic labels are unsupported or AO counts are inconsistent.
Notes
Periodic boundary conditions are parsed but not supported yet.
Molecular orbitals are assumed real-valued; complex coefficients are rejected.
Examples
>>> from jqmc.trexio_wrapper import read_trexio_file
>>> structure_data, aos_data, mos_up, mos_dn, geminal_data, coulomb_data = read_trexio_file("molecule.h5")
>>> structure_data.atomic_labels[:3]
['O', 'H', 'H']
jqmc.wavefunction module#
Wavefunction module.
- class jqmc.wavefunction.VariationalParameterBlock(name, values, shape, size)#
Bases:
object
A block of variational parameters (e.g., J1, J2, J3, lambda).
Design overview#
A block is the smallest unit that the optimizer (MCMC + SR) sees. Each block corresponds to a contiguous slice in the global variational parameter vector and carries enough metadata to reconstruct its original shape (name, values, shape, size).
This class is intentionally structure-agnostic: it does not know anything about Jastrow vs Geminal, matrix symmetry, or how a block maps to concrete fields in Jastrow_data or Geminal_data.
All physics- and structure-specific semantics are owned by the corresponding data classes via their get_variational_blocks and apply_block_update methods.
The goal is that adding or modifying a variational parameter only requires changes on the wavefunction side (Jastrow/Geminal data), while the MCMC/SR driver remains completely agnostic and operates purely on a list of blocks.
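To make the "contiguous slice" idea concrete, the following NumPy sketch packs a few parameter arrays into one global vector and recovers them from per-block (name, shape, size) metadata. The block names, shapes, and the helper unpack are illustrative assumptions, not jqmc's actual parameter layout or API.

```python
import numpy as np

# Hypothetical parameter payloads; the names and shapes are illustrative only.
params = {
    "j1_param": np.array(0.5),
    "j3_matrix": np.arange(6.0).reshape(2, 3),
    "lambda_matrix": np.ones((2, 2)),
}

# Per-block metadata in the spirit of (name, shape, size).
meta = [(name, np.shape(v), np.size(v)) for name, v in params.items()]

# The global vector is the concatenation of the flattened blocks, in order,
# so each block occupies exactly one contiguous slice.
global_vec = np.concatenate([np.ravel(v) for v in params.values()])


def unpack(vec, meta):
    """Split a global vector back into arrays using (name, shape, size)."""
    out, offset = {}, 0
    for name, shape, size in meta:
        out[name] = vec[offset:offset + size].reshape(shape)
        offset += size
    return out


restored = unpack(global_vec, meta)
```

Because only the order and the (shape, size) metadata are needed to reverse the packing, an optimizer working on global_vec never has to know what each block means.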
- apply_update(delta_flat, learning_rate)
Return a new block with values updated by a generic additive rule.
This method is intentionally structure-agnostic and only performs a simple additive update:
X_new = X_old + learning_rate * delta
Any parameter-specific constraints (e.g., symmetry of J3 or lambda_matrix) must be enforced by the owner of the parameter (jastrow_data, geminal_data, etc.) inside their apply_block_update implementations.
- Parameters:
delta_flat (ndarray[tuple[Any, ...], dtype[_ScalarT]]) – Flattened update vector with length equal to size.
learning_rate (float) – Scaling factor for the update.
- Return type:
VariationalParameterBlock
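A minimal sketch of the additive rule X_new = X_old + learning_rate * delta, written against a hypothetical frozen dataclass Block that only mimics the documented fields (name, values, shape, size). It is not jqmc's class; it shows the reshape-then-add step and the "return a new block" convention.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class Block:
    """Hypothetical stand-in for VariationalParameterBlock."""

    name: str
    values: np.ndarray
    shape: tuple
    size: int

    def apply_update(self, delta_flat, learning_rate):
        delta_flat = np.asarray(delta_flat)
        if delta_flat.size != self.size:
            raise ValueError(f"expected {self.size} entries, got {delta_flat.size}")
        # X_new = X_old + learning_rate * delta, with delta reshaped to the block's shape.
        new_values = np.asarray(self.values) + learning_rate * delta_flat.reshape(self.shape)
        # Return a new block; the original one is left untouched.
        return replace(self, values=new_values)


block = Block("j3_matrix", np.zeros((2, 2)), (2, 2), 4)
updated = block.apply_update(np.array([1.0, 2.0, 3.0, 4.0]), learning_rate=0.1)
```

No symmetry or layout constraint is applied here; in the documented design that responsibility belongs to the owner's apply_block_update.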
- name: str
Identifier for this block (for example "j1_param" or "lambda_matrix").
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- shape: tuple[int, ...]
Original shape of values, used to unflatten updates.
- size: int
Flattened size of values, used when slicing the global vector.
- values: Array | ndarray | bool | number | bool | int | float | complex
Parameter payload (keeps PyTree structure if present).
- Parameters:
name (str)
values (Array | ndarray | bool | number | bool | int | float | complex)
shape (tuple[int, ...])
size (int)
- class jqmc.wavefunction.Wavefunction_data(jastrow_data=<factory>, geminal_data=<factory>)
Bases: object
Container for Jastrow and Geminal parts used to evaluate a wavefunction.
The class owns only the data needed to construct the wavefunction. All computations are delegated to the functions in this module and the underlying Jastrow/Geminal helpers.
- Parameters:
jastrow_data (Jastrow_data) – Optional Jastrow parameters. If None, the Jastrow part is omitted.
geminal_data (Geminal_data) – Optional Geminal parameters. If None, the determinant part is omitted.
- accumulate_position_grad(grad_wavefunction)
Aggregate position gradients from geminal and Jastrow parts.
- Parameters:
grad_wavefunction (Wavefunction_data)
- apply_block_updates(blocks, thetas, learning_rate)
Return a new Wavefunction_data with variational blocks updated.
Design notes
blocks defines the ordering and shapes of all variational parameters; thetas is a single flattened update vector in the same order.
This method is responsible for slicing thetas into per-block pieces and performing a generic additive update via VariationalParameterBlock.apply_update().
The interpretation of each block (“this is J1”, “this is the J3 matrix”, “this is lambda”) and any structural constraints (symmetry, rectangular layout, etc.) are delegated to Jastrow_data.apply_block_update() and Geminal_data.apply_block_update().
Because of this separation of concerns, the MCMC/SR driver only needs to work with the flattened thetas vector and the list of blocks; it never touches Jastrow/Geminal internals directly. To add a new parameter to the optimization, one only needs to (1) expose it in get_variational_blocks(), and (2) handle it in the corresponding apply_block_update method.
- Parameters:
blocks (list[VariationalParameterBlock])
thetas (ndarray[tuple[Any, ...], dtype[_ScalarT]])
learning_rate (float)
- Return type:
Wavefunction_data
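The slicing step described in the design notes can be sketched as a running offset over the ordered blocks. This is a hypothetical, simplified version: each block is a plain (name, values) pair, and the per-block interpretation and constraint handling that jqmc delegates to Jastrow_data/Geminal_data.apply_block_update is omitted.

```python
import numpy as np


def apply_block_updates(blocks, thetas, learning_rate):
    """Hypothetical sketch: slice one flat update vector across ordered blocks.

    Each block is a (name, values) pair; no structural constraints are enforced.
    """
    thetas = np.asarray(thetas)
    total = sum(np.size(values) for _, values in blocks)
    if thetas.size != total:
        raise ValueError(f"thetas has {thetas.size} entries, blocks need {total}")
    updated, offset = [], 0
    for name, values in blocks:
        size = np.size(values)
        # The slice for this block is contiguous and follows the block order.
        delta = thetas[offset:offset + size].reshape(np.shape(values))
        updated.append((name, values + learning_rate * delta))
        offset += size
    return updated


blocks = [("j1_param", np.array([1.0])), ("j2_param", np.array([2.0, 3.0]))]
new_blocks = apply_block_updates(blocks, np.array([10.0, 20.0, 30.0]), learning_rate=0.5)
```

The driver only ever sees thetas and the block order, which is the separation of concerns the notes above describe.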
- collect_param_grads(grad_wavefunction)
Collect parameter gradients from Jastrow and Geminal into a flat dict.
- Parameters:
grad_wavefunction (Wavefunction_data)
- Return type:
dict[str, object]
- flatten_param_grads(param_grads, num_walkers)
Return parameter gradients as numpy arrays ready for storage.
The caller does not need to know the internal block structure (e.g., NN trees); any necessary flattening is handled here.
- Parameters:
param_grads (dict[str, object])
num_walkers (int)
- Return type:
dict[str, ndarray]
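As an illustration of "any necessary flattening is handled here", the sketch below turns possibly nested per-walker gradients (for example an NN parameter tree) into one 2-D array per key. The function name matches the method only for readability; the assumption that every leaf carries a leading walker axis, and the sorted-key traversal order (similar to how JAX pytrees order dict keys), are hypothetical choices, not jqmc's implementation.

```python
import numpy as np


def flatten_param_grads(param_grads, num_walkers):
    """Hypothetical sketch: turn nested per-walker gradients into 2-D arrays.

    Every leaf is assumed to have a leading walker axis of length num_walkers.
    """

    def leaves(tree):
        # Depth-first traversal over dicts (sorted keys), lists and tuples.
        if isinstance(tree, dict):
            for key in sorted(tree):
                yield from leaves(tree[key])
        elif isinstance(tree, (list, tuple)):
            for item in tree:
                yield from leaves(item)
        else:
            yield np.asarray(tree)

    out = {}
    for name, tree in param_grads.items():
        parts = [leaf.reshape(num_walkers, -1) for leaf in leaves(tree)]
        out[name] = np.concatenate(parts, axis=1)
    return out


grads = {
    "j1_param": np.ones((4, 1)),
    "jnn_params": {"w": np.zeros((4, 2, 3)), "b": np.ones((4, 3))},
}
flat = flatten_param_grads(grads, num_walkers=4)
```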
- geminal_data: Geminal_data
Variational Geminal/determinant parameters.
- get_variational_blocks(opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=False)
Collect variational parameter blocks from Jastrow and Geminal parts.
Each block corresponds to a contiguous group of variational parameters (e.g., J1, J2, J3 matrix, NN Jastrow, lambda matrix). This method only exposes the parameter arrays; the corresponding gradients are handled on the MCMC side.
- Parameters:
opt_J1_param (bool)
opt_J2_param (bool)
opt_J3_param (bool)
opt_JNN_param (bool)
opt_lambda_param (bool)
- Return type:
list[VariationalParameterBlock]
- jastrow_data: Jastrow_data
Variational Jastrow parameters.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- sanity_check()
Check attributes of the class.
This function checks the consistency among the arguments.
- Raises:
ValueError – If there is an inconsistency in a dimension of a given argument.
- Return type:
None
- with_diff_mask(*, params=True, coords=True)
Return a copy with gradients masked according to the provided flags.
- Parameters:
params (bool)
coords (bool)
- Return type:
Wavefunction_data
- with_param_grad_mask(*, opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=True)
Return a copy where disabled parameter blocks stop propagating gradients.
Developer note
- The per-block flags (opt_J1_param, etc.) decide which high-level blocks are masked. Disabled blocks are wrapped with DiffMask(params=False, coords=True), meaning parameter gradients are stopped while coordinate gradients still flow.
- Within each disabled block, apply_diff_mask uses field-name heuristics (see diff_mask._PARAM_FIELD_NAMES) to tag parameter leaves such as lambda_matrix, j_matrix, jastrow_1b_param, jastrow_2b_param, jastrow_3b_param, and params. Those tagged leaves receive jax.lax.stop_gradient, so their backpropagated gradients become zero.
- Example: if opt_J1_param=False and the others are True, only the J1 block is masked; its parameter leaves are stopped, while J2/J3/NN/lambda continue to propagate gradients normally.
- Parameters:
opt_J1_param (bool)
opt_J2_param (bool)
opt_J3_param (bool)
opt_JNN_param (bool)
opt_lambda_param (bool)
- Return type:
Wavefunction_data
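The net effect described in the developer note can be emulated without JAX. The sketch below uses a hypothetical toy ln Psi that is linear in two parameter blocks: a disabled block still contributes to the value, but its gradient entries come out as zero, which is what jax.lax.stop_gradient produces for a masked leaf. It reproduces only that outcome; it does not use jqmc's DiffMask or apply_diff_mask.

```python
import numpy as np


def toy_value_and_grads(thetas, features, opt_flags):
    """Toy ln Psi = sum_k <theta_k, f_k> and its per-block parameter gradients.

    Block names and the linear model are hypothetical. Disabled blocks keep
    their contribution to the value but receive a zero gradient.
    """
    value = sum(float(np.dot(thetas[k], features[k])) for k in thetas)
    grads = {}
    for k in thetas:
        grad = np.asarray(features[k], dtype=float)  # d<theta, f>/dtheta = f
        grads[k] = grad if opt_flags.get(k, True) else np.zeros_like(grad)
    return value, grads


thetas = {"J1": np.array([2.0]), "J2": np.array([0.5, 0.5])}
features = {"J1": np.array([0.3]), "J2": np.array([1.0, -2.0])}
value, grads = toy_value_and_grads(thetas, features, {"J1": False})
```

With opt_J1_param=False in the documented method, the analogous outcome is a zero J1 gradient while J2 and the other blocks are unaffected.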
- jqmc.wavefunction.compute_discretized_kinetic_energy(alat, wavefunction_data, r_up_carts, r_dn_carts, RT)
Compute discretized kinetic mesh points and energies for a given lattice spacing alat.
This keeps the original semantics used by the LRDMC path: ratios are computed as exp(J_xp - J_x) * det_xp / det_x, where x is the current configuration and xp a displaced mesh configuration. Inputs are coerced to float64 jax.Array before evaluation.
- Parameters:
alat (float) – Hamiltonian discretization (bohr), which will be replaced with LRDMC_data.
wavefunction_data – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Up-electron positions with shape (n_up, 3).
r_dn_carts (Array) – Down-electron positions with shape (n_dn, 3).
RT (Array) – Rotation matrix (\(R^T\)) with shape (3, 3).
- Returns:
A tuple (r_up_carts_combined, r_dn_carts_combined, elements_kinetic_part), where the combined coordinate arrays have shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3), and elements_kinetic_part contains the kinetic prefactor-scaled ratios.
- Return type:
tuple[list[tuple[ndarray[tuple[Any, …], dtype[_ScalarT]], ndarray[tuple[Any, …], dtype[_ScalarT]]]], list[ndarray[tuple[Any, …], dtype[_ScalarT]]], Array]
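A minimal NumPy sketch of an LRDMC-style kinetic mesh, under stated assumptions: each electron in turn is displaced by plus or minus alat along the three rotated axes (taken here as the rows of RT), giving n_grid = 6 * (n_up + n_dn) configurations, and the off-diagonal elements use the prefactor -1/(2 alat^2) of the simplest finite-difference Laplacian. The mesh ordering, the exact prefactor, and both helper names are hypothetical; jqmc's actual grid construction is not spelled out in this docstring.

```python
import numpy as np


def discretized_kinetic_mesh(alat, r_up, r_dn, RT):
    """Hypothetical mesh: move one electron at a time by +/- alat along rows of RT."""
    r_up, r_dn = np.asarray(r_up, float), np.asarray(r_dn, float)
    RT = np.asarray(RT, float)
    shifts = np.concatenate([alat * RT, -alat * RT])  # shape (6, 3)
    up_mesh, dn_mesh = [], []
    for spin, n in (("up", len(r_up)), ("dn", len(r_dn))):
        for i in range(n):
            for d in shifts:
                new_up, new_dn = r_up.copy(), r_dn.copy()
                if spin == "up":
                    new_up[i] += d
                else:
                    new_dn[i] += d
                up_mesh.append(new_up)
                dn_mesh.append(new_dn)
    # Shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3).
    return np.stack(up_mesh), np.stack(dn_mesh)


def kinetic_elements(alat, jastrow_x, det_x, jastrow_xp, det_xp):
    """Assumed off-diagonal elements -1/(2 alat^2) * exp(J_xp - J_x) * det_xp / det_x."""
    ratios = np.exp(np.asarray(jastrow_xp) - jastrow_x) * np.asarray(det_xp) / det_x
    return -0.5 / alat**2 * ratios


r_up = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
r_dn = np.array([[0.0, 1.0, 0.0]])
up_mesh, dn_mesh = discretized_kinetic_mesh(0.1, r_up, r_dn, np.eye(3))
elements = kinetic_elements(0.1, 0.0, 1.0, np.zeros(18), np.ones(18))
```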
- jqmc.wavefunction.compute_discretized_kinetic_energy_fast_update(alat, wavefunction_data, A_old_inv, r_up_carts, r_dn_carts, RT)
Fast-update version of the discretized kinetic mesh and ratios.
Computes the discretized kinetic grid points and their energies for the lattice spacing alat, using the precomputed A_old_inv to evaluate determinant ratios efficiently. Inputs are converted to float64 jax.Array before use.
- Parameters:
alat (float) – Hamiltonian discretization (bohr), which will be replaced with LRDMC_data.
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
A_old_inv (Array) – Inverse of the geminal matrix evaluated at (r_up_carts, r_dn_carts).
r_up_carts (Array) – Up-electron positions with shape (n_up, 3).
r_dn_carts (Array) – Down-electron positions with shape (n_dn, 3).
RT (Array) – Rotation matrix (\(R^T\)) with shape (3, 3).
- Returns:
A tuple (r_up_carts_combined, r_dn_carts_combined, elements_kinetic_part) with combined coordinate arrays of shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3), and the kinetic prefactor-scaled ratios elements_kinetic_part.
- Return type:
tuple[Array, Array, Array]
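Why a precomputed inverse speeds things up: when a single electron moves, only one row (or column) of the geminal matrix changes, and the matrix determinant lemma gives the determinant ratio as a dot product with one column (or row) of the old inverse, an O(n) operation instead of an O(n^3) determinant. The sketch assumes a square toy matrix whose rows belong to up electrons and columns to down electrons; jqmc's layout for unequal spin counts is not covered, and the helper names are hypothetical.

```python
import numpy as np


def row_update_ratio(A_inv, new_row, i):
    """det(A') / det(A) when only row i of A is replaced by new_row (up-electron move)."""
    return float(np.dot(new_row, A_inv[:, i]))


def col_update_ratio(A_inv, new_col, j):
    """det(A') / det(A) when only column j of A is replaced by new_col (down-electron move)."""
    return float(np.dot(A_inv[j, :], new_col))


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)  # well-conditioned toy "geminal" matrix
A_inv = np.linalg.inv(A)

new_row = rng.normal(size=4)
A_row = A.copy()
A_row[1, :] = new_row

new_col = rng.normal(size=4)
A_col = A.copy()
A_col[:, 2] = new_col

ratio_row = row_update_ratio(A_inv, new_row, 1)
ratio_col = col_update_ratio(A_inv, new_col, 2)
```

Both ratios agree with the direct np.linalg.det quotients, which is the check a fast-update path can be tested against.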
- jqmc.wavefunction.compute_kinetic_energy(wavefunction_data, r_up_carts, r_dn_carts)
Compute kinetic energy using analytic gradients and Laplacians.
Computes the kinetic energy of the given wavefunction at (r_up_carts, r_dn_carts), relying fully on JAX for the calculation. Inputs are converted to float64 jax.Array for consistency with other compute utilities.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Kinetic energy evaluated for the supplied configuration.
- Return type:
float | complex
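For reference, the standard identity behind an analytic-gradient evaluation of the local kinetic energy (Hartree atomic units) expresses \(\nabla^2\Psi/\Psi\) through derivatives of \(\ln|\Psi|\); it holds away from the nodes of \(\Psi\), with the sum running over all up- and down-spin electrons. Whether jqmc assembles the terms in exactly this form is not stated in this docstring.

```latex
T_L(\mathbf{R}) = -\frac{1}{2}\sum_{i}\frac{\nabla_i^2 \Psi(\mathbf{R})}{\Psi(\mathbf{R})}
  = -\frac{1}{2}\sum_{i}\left[\nabla_i^2 \ln|\Psi| + \left|\nabla_i \ln|\Psi|\right|^2\right],
\qquad \ln|\Psi| = J + \ln|\det|
```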
- jqmc.wavefunction.compute_kinetic_energy_all_elements(wavefunction_data, r_up_carts, r_dn_carts)
Analytic-derivative kinetic energy per electron (matches auto output shape).
Returns the per-electron kinetic energy using analytic gradients/Laplacians of both the Jastrow and determinant parts. Shapes align with _compute_kinetic_energy_all_elements_auto.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Tuple of two jax.Array objects containing per-electron kinetic energies for spin-up and spin-down electrons, respectively.
- Return type:
Array
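A per-electron version of the identity \(-\tfrac{1}{2}(\nabla_i^2 \ln|\Psi| + |\nabla_i \ln|\Psi||^2)\) for one spin channel, as a NumPy sketch. The inputs (per-electron gradients and Laplacians of ln|Psi|) are assumed to be available; the check uses a hypothetical product of hydrogen-like 1s factors exp(-r_i), for which the local kinetic energy of each electron is 1/r_i - 1/2.

```python
import numpy as np


def kinetic_per_electron(grad_ln_psi, lap_ln_psi):
    """Per-electron local kinetic energy -1/2 (lap ln|Psi| + |grad ln|Psi||^2).

    grad_ln_psi has shape (n, 3); lap_ln_psi has shape (n,).
    """
    grad = np.asarray(grad_ln_psi, float)
    lap = np.asarray(lap_ln_psi, float)
    return -0.5 * (lap + np.sum(grad**2, axis=1))


# Product of 1s factors exp(-r_i): grad_i ln Psi = -r_hat_i, lap_i ln Psi = -2 / r_i.
r = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
dist = np.linalg.norm(r, axis=1)
t_up = kinetic_per_electron(-r / dist[:, None], -2.0 / dist)
```

Summing the spin-up and spin-down arrays gives the total local kinetic energy returned by compute_kinetic_energy.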
- jqmc.wavefunction.compute_kinetic_energy_all_elements_fast_update(wavefunction_data, r_up_carts, r_dn_carts, geminal_inverse)
Kinetic energy per electron using a precomputed geminal inverse.
- Parameters:
wavefunction_data (Wavefunction_data)
r_up_carts (Array)
r_dn_carts (Array)
geminal_inverse (Array)
- Return type:
Array
- jqmc.wavefunction.compute_quantum_force(wavefunction_data, r_up_carts, r_dn_carts)
Compute the quantum force 2 * grad ln |Psi| at the given coordinates.
Computes the quantum force at (r_up_carts, r_dn_carts). Gradients from the Jastrow part are currently set to zero (as in the original implementation); determinant gradients are included via compute_grads_and_laplacian_ln_Det. Inputs are coerced to float64 jax.Array for consistency.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
A tuple (force_up, force_dn) with shapes matching the input coordinate arrays.
- Return type:
tuple[Array, Array]
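A finite-difference reference for the definition F = 2 * grad ln|Psi|, useful for sanity-checking an analytic implementation. The helper quantum_force_fd and the toy ln|Psi| (one hydrogen-like 1s factor per electron, so F_i = -2 r_i / |r_i|) are hypothetical. To mirror the current behavior described above, where Jastrow gradients are set to zero, one would pass only ln|det| as the callable.

```python
import numpy as np


def quantum_force_fd(ln_abs_psi, r, h=1e-5):
    """F = 2 * grad ln|Psi| via central finite differences.

    ln_abs_psi maps an (n, 3) coordinate array to a scalar.
    """
    r = np.asarray(r, float)
    grad = np.zeros_like(r)
    for idx in np.ndindex(r.shape):
        step = np.zeros_like(r)
        step[idx] = h
        grad[idx] = (ln_abs_psi(r + step) - ln_abs_psi(r - step)) / (2 * h)
    return 2.0 * grad


# Toy ln|Psi| = -sum_i |r_i|, whose exact quantum force is -2 r_i / |r_i|.
def ln_psi(x):
    return -np.sum(np.linalg.norm(x, axis=1))


r = np.array([[1.0, 2.0, 2.0]])
F = quantum_force_fd(ln_psi, r)
```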
- jqmc.wavefunction.evaluate_determinant(wavefunction_data, r_up_carts, r_dn_carts)
Evaluate the determinant (Geminal) part of the wavefunction.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Determinant value evaluated at the supplied coordinates.
- Return type:
float
- jqmc.wavefunction.evaluate_jastrow(wavefunction_data, r_up_carts, r_dn_carts)
Evaluate the Jastrow factor \(\exp(J)\) at the given coordinates.
Evaluates the Jastrow part of the wavefunction (Psi) at (r_up_carts, r_dn_carts). The returned value already includes the exponential, i.e., exp(J).
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Real Jastrow factor exp(J).
- Return type:
float
- jqmc.wavefunction.evaluate_ln_wavefunction(wavefunction_data, r_up_carts, r_dn_carts)
Evaluate the logarithm of |wavefunction| (\(\ln |\Psi|\)).
This follows the original behavior: the full Jastrow exponent J is kept and combined with log(abs(det)) of the determinant part, giving \(J + \ln|\det|\); because \(\exp(J) > 0\), this equals \(\ln|\Psi|\). The inputs are converted to float64 jax.Array for downstream consistency.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Scalar log-value of the wavefunction magnitude.
- Return type:
float
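Working in log space avoids overflow that the direct product exp(J) * det can hit for large Jastrow exponents. The NumPy sketch below computes J + ln|det| with numpy.linalg.slogdet and contrasts it with the naive product; the helper name and the toy inputs are hypothetical, and whether jqmc uses slogdet internally is not stated here.

```python
import numpy as np


def ln_abs_psi(jastrow_exponent, geminal_matrix):
    """ln|Psi| = J + ln|det G|, computed without forming exp(J) * det G."""
    sign, logdet = np.linalg.slogdet(geminal_matrix)
    if sign == 0:
        return -np.inf  # exactly on a node: |Psi| = 0
    return jastrow_exponent + logdet


J = 800.0
G = np.diag([1e-3, 2e-3, 5e-3])  # det G = 1e-8
value = ln_abs_psi(J, G)

with np.errstate(over="ignore"):
    naive = np.exp(J) * np.linalg.det(G)  # exp(800) overflows float64 to inf
```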
- jqmc.wavefunction.evaluate_wavefunction(wavefunction_data, r_up_carts, r_dn_carts)
Evaluate the wavefunction Psi at the given electron coordinates.
Evaluates the wavefunction (Psi) at (r_up_carts, r_dn_carts) and returns exp(Jastrow) * Determinant. Inputs are coerced to float64 jax.Array to match other compute utilities.
- Parameters:
wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).
r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).
r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).
- Returns:
Complex or real wavefunction value.
- Return type:
float | complex