jqmc package

Submodules

jqmc.atomic_orbital module

Atomic Orbitals module.

Module containing classes and methods related to atomic orbitals.

class jqmc.atomic_orbital.AOs_cart_data(structure_data=<factory>, nucleus_index=<factory>, num_ao=0, num_ao_prim=0, orbital_indices=<factory>, exponents=<factory>, coefficients=<factory>, angular_momentums=<factory>, polynominal_order_x=<factory>, polynominal_order_y=<factory>, polynominal_order_z=<factory>)

Bases: object

Atomic orbital definitions in Cartesian form.

Stores contracted Gaussian basis data used to evaluate atomic orbitals (AOs) on a grid. The angular part is represented by Cartesian polynomials (\(x^{n_x} y^{n_y} z^{n_z}\)) with \(n_x + n_y + n_z = l\) for each AO.
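The constraint \(n_x + n_y + n_z = l\) means each shell of angular momentum l expands into (l+1)(l+2)/2 Cartesian AOs. The sketch below uses two hypothetical helpers, cartesian_orders and build_polynomial_orders; neither is a jqmc function. Together they enumerate those powers and expand per-shell angular momenta into the per-AO lists angular_momentums and polynominal_order_{x,y,z}. The ordering is n_x descending, then n_y descending, which gives xx, xy, xz, yy, yz, zz for l = 2. This reproduces the example below. Whether jqmc requires this particular ordering is an assumption.

```python
def cartesian_orders(l: int) -> list[tuple[int, int, int]]:
    """Return all (n_x, n_y, n_z) with n_x + n_y + n_z == l.

    Ordering: n_x descending, then n_y descending (xx, xy, xz, yy, yz, zz for l=2).
    """
    orders = []
    for nx in range(l, -1, -1):
        for ny in range(l - nx, -1, -1):
            orders.append((nx, ny, l - nx - ny))
    return orders


def build_polynomial_orders(shell_ls: list[int]):
    """Expand per-shell angular momenta into per-AO lists (hypothetical helper)."""
    ang, px, py, pz = [], [], [], []
    for l in shell_ls:
        for nx, ny, nz in cartesian_orders(l):
            ang.append(l)
            px.append(nx)
            py.append(ny)
            pz.append(nz)
    return ang, px, py, pz


# One H atom with cc-pVTZ shells: S, S, S, P, P, D
ang, px, py, pz = build_polynomial_orders([0, 0, 0, 1, 1, 2])
print(len(ang))  # 15 Cartesian AOs per atom
```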

Variables:
  • structure_data (Structure_data) – Molecular structure and atomic positions used to place AOs.

  • nucleus_index (list[int] | tuple[int]) – AO -> atom index mapping (len == num_ao).

  • num_ao (int) – Number of contracted AOs.

  • num_ao_prim (int) – Number of primitive Gaussians.

  • orbital_indices (list[int] | tuple[int]) – For each primitive, the parent AO index (len == num_ao_prim).

  • exponents (list[float] | tuple[float]) – Gaussian exponents for primitives (len == num_ao_prim).

  • coefficients (list[float] | tuple[float]) – Contraction coefficients per primitive (len == num_ao_prim).

  • angular_momentums (list[int] | tuple[int]) – Angular momentum quantum numbers l per AO (len == num_ao).

  • polynominal_order_x (list[int] | tuple[int]) – Cartesian power n_x for each AO (len == num_ao).

  • polynominal_order_y (list[int] | tuple[int]) – Cartesian power n_y for each AO (len == num_ao).

  • polynominal_order_z (list[int] | tuple[int]) – Cartesian power n_z for each AO (len == num_ao).

Parameters:
  • structure_data (Structure_data)

  • nucleus_index (list[int] | tuple[int])

  • num_ao (int)

  • num_ao_prim (int)

  • orbital_indices (list[int] | tuple[int])

  • exponents (list[float] | tuple[float])

  • coefficients (list[float] | tuple[float])

  • angular_momentums (list[int] | tuple[int])

  • polynominal_order_x (list[int] | tuple[int])

  • polynominal_order_y (list[int] | tuple[int])

  • polynominal_order_z (list[int] | tuple[int])

Examples

Minimal hydrogen dimer (bohr) with all-electron cc-pVTZ (Gaussian format) in Cartesian form:

from jqmc.structure import Structure_data
from jqmc.atomic_orbital import AOs_cart_data

structure = Structure_data(
    positions=[[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]],
    pbc_flag=False,
    atomic_numbers=[1, 1],
    element_symbols=["H", "H"],
    atomic_labels=["H1", "H2"],
)

# cc-pVTZ primitives duplicated per Cartesian component.
# Per atom: 15 AOs and 19 primitives; for two atoms:
num_ao = 30
num_ao_prim = 38

exponents = [
    0.3258,
    33.87, 5.095, 1.159, 0.3258, 0.1027,
    0.1027,
    1.407, 1.407, 1.407,
    0.388, 0.388, 0.388,
    1.057, 1.057, 1.057, 1.057, 1.057, 1.057,
    0.3258,
    33.87, 5.095, 1.159, 0.3258, 0.1027,
    0.1027,
    1.407, 1.407, 1.407,
    0.388, 0.388, 0.388,
    1.057, 1.057, 1.057, 1.057, 1.057, 1.057,
]

coefficients = [
    1.0,
    0.006068, 0.045308, 0.202822, 0.503903, 0.383421,
    1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    1.0,
    0.006068, 0.045308, 0.202822, 0.503903, 0.383421,
    1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
]

orbital_indices = [
    0,                      # S shell (uncontracted)
    1, 1, 1, 1, 1,          # S shell (contracted)
    2,                      # S shell (uncontracted)
    3, 4, 5,                # P shell (uncontracted)
    6, 7, 8,                # P shell (uncontracted)
    9, 10, 11, 12, 13, 14,  # D shell (uncontracted)
    15,                     # S shell (uncontracted)
    16, 16, 16, 16, 16,     # S shell (contracted)
    17,                     # S shell (uncontracted)
    18, 19, 20,             # P shell (uncontracted)
    21, 22, 23,             # P shell (uncontracted)
    24, 25, 26, 27, 28, 29, # D shell (uncontracted)
]

nucleus_index = [
    *([0] * 15),
    *([1] * 15),
]

angular_momentums = [
    0, 0, 0,                # three S shells (1 orbital each)
    1, 1, 1, 1, 1, 1,       # two P shells (3 orbitals each)
    2, 2, 2, 2, 2, 2,       # one D shell (6 Cartesian components)
    0, 0, 0,                # three S shells (1 orbital each)
    1, 1, 1, 1, 1, 1,       # two P shells (3 orbitals each)
    2, 2, 2, 2, 2, 2,       # one D shell (6 Cartesian components)
]

polynominal_order_x = [
    0, 0, 0,
    1, 0, 0,
    1, 0, 0,
    2, 1, 1, 0, 0, 0,
    0, 0, 0,
    1, 0, 0,
    1, 0, 0,
    2, 1, 1, 0, 0, 0,
]

polynominal_order_y = [
    0, 0, 0,
    0, 1, 0,
    0, 1, 0,
    0, 1, 0, 2, 1, 0,
    0, 0, 0,
    0, 1, 0,
    0, 1, 0,
    0, 1, 0, 2, 1, 0,
]

polynominal_order_z = [
    0, 0, 0,
    0, 0, 1,
    0, 0, 1,
    0, 0, 1, 0, 1, 2,
    0, 0, 0,
    0, 0, 1,
    0, 0, 1,
    0, 0, 1, 0, 1, 2,
]

aos = AOs_cart_data(
    structure_data=structure,
    nucleus_index=nucleus_index,
    num_ao=num_ao,
    num_ao_prim=num_ao_prim,
    orbital_indices=orbital_indices,
    exponents=exponents,
    coefficients=coefficients,
    angular_momentums=angular_momentums,
    polynominal_order_x=polynominal_order_x,
    polynominal_order_y=polynominal_order_y,
    polynominal_order_z=polynominal_order_z,
)

aos.sanity_check()

angular_momentums: list[int] | tuple[int]

Angular momentum quantum numbers l per AO (len == num_ao).

coefficients: list[float] | tuple[float]

Contraction coefficients per primitive (len == num_ao_prim).

exponents: list[float] | tuple[float]

Gaussian exponents for primitives (len == num_ao_prim).

nucleus_index: list[int] | tuple[int]

AO -> atom index mapping (len == num_ao).

num_ao: int = 0

Number of contracted AOs.

num_ao_prim: int = 0

Number of primitive Gaussians.

orbital_indices: list[int] | tuple[int]

For each primitive, the parent AO index (len == num_ao_prim).

polynominal_order_x: list[int] | tuple[int]

Cartesian power n_x for each AO (len == num_ao).

polynominal_order_y: list[int] | tuple[int]

Cartesian power n_y for each AO (len == num_ao).

polynominal_order_z: list[int] | tuple[int]

Cartesian power n_z for each AO (len == num_ao).

replace(**updates)

Returns a new object replacing the specified fields with new values.

sanity_check()

Validate AO array shapes and basic types.

Ensures that array lengths match declared counts (num_ao, num_ao_prim) and that inputs are provided as lists/tuples (ints for counts). Call this after constructing AOs_cart_data and before AO evaluation routines.

Raises:

ValueError – When any of the following holds:

  • len(nucleus_index) != num_ao

  • len(unique(orbital_indices)) != num_ao

  • len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim

  • len(angular_momentums) != num_ao

  • len(polynominal_order_{x,y,z}) != num_ao

  • any attribute has an unexpected Python type (e.g., non-list/tuple arrays, non-int counts)

Return type:

None
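The length conditions listed under Raises can be expressed as a short standalone function. This is a minimal sketch, not jqmc's code. The name check_ao_lengths is hypothetical, and per_ao_lists stands for whichever per-AO sequences the class carries: angular_momentums plus polynominal_order_{x,y,z} here, or angular_momentums plus magnetic_quantum_numbers for the spherical class.

```python
def check_ao_lengths(num_ao, num_ao_prim, nucleus_index, orbital_indices,
                     exponents, coefficients, per_ao_lists):
    """Hypothetical re-implementation of the length checks listed under Raises."""
    if not isinstance(num_ao, int) or not isinstance(num_ao_prim, int):
        raise ValueError("num_ao and num_ao_prim must be int")
    seqs = [nucleus_index, orbital_indices, exponents, coefficients, *per_ao_lists]
    if not all(isinstance(s, (list, tuple)) for s in seqs):
        raise ValueError("array attributes must be lists or tuples")
    if len(nucleus_index) != num_ao:
        raise ValueError("len(nucleus_index) != num_ao")
    if len(set(orbital_indices)) != num_ao:
        raise ValueError("len(unique(orbital_indices)) != num_ao")
    if len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim:
        raise ValueError("primitive arrays do not match num_ao_prim")
    for seq in per_ao_lists:
        if len(seq) != num_ao:
            raise ValueError("a per-AO list does not have length num_ao")


# One s-type AO contracted from two primitives: passes silently
check_ao_lengths(1, 2, [0], [0, 0], [1.0, 0.5], [0.6, 0.4], [[0], [0], [0], [0]])
```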

structure_data: Structure_data

Molecular structure and atomic positions used to place AOs.

class jqmc.atomic_orbital.AOs_sphe_data(structure_data=<factory>, nucleus_index=<factory>, num_ao=0, num_ao_prim=0, orbital_indices=<factory>, exponents=<factory>, coefficients=<factory>, angular_momentums=<factory>, magnetic_quantum_numbers=<factory>)

Bases: object

Atomic orbital definitions in real spherical-harmonic form.

Stores contracted Gaussian basis data for atomic orbitals (AOs) whose angular part is a real spherical harmonic \(Y_{l}^{m}\) with \(m \in \{-l,\dots,+l\}\).
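A real spherical shell contributes 2l+1 AOs, whereas the Cartesian form needs (l+1)(l+2)/2 AOs. For the hydrogen cc-pVTZ basis used in the examples this gives 14 versus 15 AOs per atom. The sketch below builds angular_momentums and magnetic_quantum_numbers from a list of shell angular momenta. The helper names are hypothetical. The ordering m = -l, ..., +l follows the example below; whether jqmc requires it is an assumption.

```python
def sphe_shell_magnetic_numbers(l: int) -> list[int]:
    # m runs from -l to +l, the order used in the example below
    return list(range(-l, l + 1))


def n_cart(l: int) -> int:
    # number of (n_x, n_y, n_z) with n_x + n_y + n_z = l
    return (l + 1) * (l + 2) // 2


shell_ls = [0, 0, 0, 1, 1, 2]  # cc-pVTZ on H: 3 S, 2 P, 1 D shells
ang, mag = [], []
for l in shell_ls:
    for m in sphe_shell_magnetic_numbers(l):
        ang.append(l)
        mag.append(m)

print(len(ang), sum(n_cart(l) for l in shell_ls))  # 14 15
```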

Variables:
  • structure_data (Structure_data) – Molecular structure and atomic positions used to place AOs.

  • nucleus_index (list[int] | tuple[int]) – AO -> atom index mapping (len == num_ao).

  • num_ao (int) – Number of contracted AOs.

  • num_ao_prim (int) – Number of primitive Gaussians.

  • orbital_indices (list[int] | tuple[int]) – For each primitive, the parent AO index (len == num_ao_prim).

  • exponents (list[float] | tuple[float]) – Gaussian exponents for primitives (len == num_ao_prim).

  • coefficients (list[float] | tuple[float]) – Contraction coefficients per primitive (len == num_ao_prim).

  • angular_momentums (list[int] | tuple[int]) – Angular momentum quantum numbers l per AO (len == num_ao).

  • magnetic_quantum_numbers (list[int] | tuple[int]) – Magnetic quantum numbers m per AO (len == num_ao), satisfying -l <= m <= l.

Parameters:
  • structure_data (Structure_data)

  • nucleus_index (list[int] | tuple[int])

  • num_ao (int)

  • num_ao_prim (int)

  • orbital_indices (list[int] | tuple[int])

  • exponents (list[float] | tuple[float])

  • coefficients (list[float] | tuple[float])

  • angular_momentums (list[int] | tuple[int])

  • magnetic_quantum_numbers (list[int] | tuple[int])

Examples

Hydrogen dimer (bohr) with all-electron cc-pVTZ (Gaussian format), real spherical harmonics:

from jqmc.structure import Structure_data
from jqmc.atomic_orbital import AOs_sphe_data

structure = Structure_data(
    positions=[[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]],
    pbc_flag=False,
    atomic_numbers=[1, 1],
    element_symbols=["H", "H"],
    atomic_labels=["H1", "H2"],
)

# Per atom: 14 AOs (3 S, 2×P shells -> 6, 1×D shell -> 5); 18 primitives.
# Two atoms -> num_ao=28, num_ao_prim=36.
exponents = [
    # atom 1
    0.3258,
    33.87, 5.095, 1.159, 0.3258, 0.1027,
    0.1027,
    1.407, 1.407, 1.407,
    0.388, 0.388, 0.388,
    1.057, 1.057, 1.057, 1.057, 1.057,
    # atom 2 (same order)
    0.3258,
    33.87, 5.095, 1.159, 0.3258, 0.1027,
    0.1027,
    1.407, 1.407, 1.407,
    0.388, 0.388, 0.388,
    1.057, 1.057, 1.057, 1.057, 1.057,
]

coefficients = [
    # atom 1
    1.0,
    0.006068, 0.045308, 0.202822, 0.503903, 0.383421,
    1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0,
    # atom 2 (same order)
    1.0,
    0.006068, 0.045308, 0.202822, 0.503903, 0.383421,
    1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0,
    1.0, 1.0, 1.0, 1.0, 1.0,
]

orbital_indices = [
    # atom 1 AOs 0-13
    0,
    1, 1, 1, 1, 1,
    2,
    3, 4, 5,
    6, 7, 8,
    9, 10, 11, 12, 13,
    # atom 2 AOs 14-27
    14,
    15, 15, 15, 15, 15,
    16,
    17, 18, 19,
    20, 21, 22,
    23, 24, 25, 26, 27,
]

nucleus_index = [
    *([0] * 14),
    *([1] * 14),
]

angular_momentums = [
    # atom 1
    0, 0, 0,                # S shells
    1, 1, 1, 1, 1, 1,       # two P shells (3 each)
    2, 2, 2, 2, 2,          # one D shell (5)
    # atom 2
    0, 0, 0,
    1, 1, 1, 1, 1, 1,
    2, 2, 2, 2, 2,
]

magnetic_quantum_numbers = [
    # atom 1 (S: m=0; P: -1,0,1; D: -2..2)
    0, 0, 0,
    -1, 0, 1,
    -1, 0, 1,
    -2, -1, 0, 1, 2,
    # atom 2
    0, 0, 0,
    -1, 0, 1,
    -1, 0, 1,
    -2, -1, 0, 1, 2,
]

aos = AOs_sphe_data(
    structure_data=structure,
    nucleus_index=nucleus_index,
    num_ao=len(angular_momentums),
    num_ao_prim=len(exponents),
    orbital_indices=orbital_indices,
    exponents=exponents,
    coefficients=coefficients,
    angular_momentums=angular_momentums,
    magnetic_quantum_numbers=magnetic_quantum_numbers,
)

aos.sanity_check()

angular_momentums: list[int] | tuple[int]

Angular momentum quantum numbers l per AO (len == num_ao).

coefficients: list[float] | tuple[float]

Contraction coefficients per primitive (len == num_ao_prim).

exponents: list[float] | tuple[float]

Gaussian exponents for primitives (len == num_ao_prim).

magnetic_quantum_numbers: list[int] | tuple[int]

Magnetic quantum numbers m per AO (len == num_ao; -l <= m <= l).

nucleus_index: list[int] | tuple[int]

AO -> atom index mapping (len == num_ao).

num_ao: int = 0

Number of contracted AOs.

num_ao_prim: int = 0

Number of primitive Gaussians.

orbital_indices: list[int] | tuple[int]

For each primitive, the parent AO index (len == num_ao_prim).

replace(**updates)

Returns a new object replacing the specified fields with new values.

sanity_check()

Validate AO array shapes and basic types.

Ensures that array lengths match declared counts (num_ao, num_ao_prim) and that inputs are provided as lists/tuples (ints for counts). Call this after constructing AOs_sphe_data and before AO evaluation routines.

Raises:

ValueError – When any of the following holds:

  • len(nucleus_index) != num_ao

  • len(unique(orbital_indices)) != num_ao

  • len(exponents) != num_ao_prim or len(coefficients) != num_ao_prim

  • len(angular_momentums) != num_ao

  • len(magnetic_quantum_numbers) != num_ao

  • any attribute has an unexpected Python type (e.g., non-list/tuple arrays, non-int counts)

Return type:

None

structure_data: Structure_data

Molecular structure and atomic positions used to place AOs.

jqmc.atomic_orbital.compute_AOs(aos_data, r_carts)

Evaluate contracted atomic orbitals (AOs) at electron coordinates.

Dispatches to Cartesian or real-spherical backends and returns float64 JAX arrays (ensure jax_enable_x64=True). Call aos_data.sanity_check() before use.

Parameters:
  • aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing centers, primitive parameters, angular data, and contraction mapping.

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in Bohr. Casts to float64 internally via jnp.asarray.

Returns:

AO values, shape (num_ao, N_e).

Return type:

jax.Array

Raises:

NotImplementedError – If aos_data is neither Cartesian nor spherical.
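To show how the flat AOs_cart_data arrays combine into AO values with the (num_ao, N_e) layout returned here, the NumPy sketch below evaluates contracted Cartesian AOs. It works in three steps:

  • Each primitive is mapped to its parent AO through orbital_indices and to its atom through nucleus_index.
  • Each primitive's polynomial-times-Gaussian value is evaluated at every electron.
  • The primitives are summed into their AOs with np.add.at.

It is not jqmc's JAX implementation. Primitive normalization is deliberately omitted, and eval_cart_aos is a hypothetical name.

```python
import numpy as np


def eval_cart_aos(centers, nucleus_index, orbital_indices, exponents, coefficients,
                  nx, ny, nz, num_ao, r_carts):
    """Evaluate contracted Cartesian AOs; returns shape (num_ao, N_e).

    Primitive normalization is omitted (an assumption of this sketch).
    """
    centers = np.asarray(centers, dtype=np.float64)
    r = np.asarray(r_carts, dtype=np.float64)               # (N_e, 3)
    prim_ao = np.asarray(orbital_indices)                   # parent AO of each primitive
    prim_atom = np.asarray(nucleus_index)[prim_ao]          # atom of each primitive
    d = r[None, :, :] - centers[prim_atom][:, None, :]      # (num_ao_prim, N_e, 3)
    r2 = np.sum(d**2, axis=-1)                              # (num_ao_prim, N_e)
    poly = (d[..., 0] ** np.asarray(nx)[prim_ao][:, None]
            * d[..., 1] ** np.asarray(ny)[prim_ao][:, None]
            * d[..., 2] ** np.asarray(nz)[prim_ao][:, None])
    prim_vals = (np.asarray(coefficients, dtype=np.float64)[:, None] * poly
                 * np.exp(-np.asarray(exponents, dtype=np.float64)[:, None] * r2))
    aos = np.zeros((num_ao, r.shape[0]))
    np.add.at(aos, prim_ao, prim_vals)                      # contract primitives into AOs
    return aos


# An s AO (two primitives) and a p_x AO (one primitive) on one atom at the origin
vals = eval_cart_aos([[0.0, 0.0, 0.0]], [0, 0], [0, 0, 1], [1.0, 0.5, 0.8],
                     [0.6, 0.4, 1.0], [0, 1], [0, 0], [0, 0], 2,
                     [[0.1, 0.2, 0.3], [1.0, 0.0, 0.0]])
print(vals.shape)  # (2, 2)
```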

jqmc.atomic_orbital.compute_AOs_grad(aos_data, r_carts)

Return analytic Cartesian gradients of contracted atomic orbitals.

Public gradient API used for drift vectors and kinetic terms. Dispatches to Cartesian or real-spherical backends; returns float64 JAX arrays (ensure jax_enable_x64=True).

Parameters:
  • aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing primitive parameters, angular info, contraction mapping, and centers (run sanity_check() beforehand).

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) (Bohr). Casts to float64 internally via jnp.asarray.

Returns:

Gradients w.r.t. x, y, z, each of shape (num_ao, N_e). Order is (gx, gy, gz).

Return type:

tuple[jax.Array, jax.Array, jax.Array]

Raises:

NotImplementedError – If aos_data is neither Cartesian nor spherical.
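For a Cartesian primitive \(x^{n_x} y^{n_y} z^{n_z} e^{-\alpha r^2}\), the product rule gives \(\partial_x = (n_x x^{n_x-1} - 2\alpha x^{n_x+1})\, y^{n_y} z^{n_z} e^{-\alpha r^2}\), and likewise for y and z. The sketch below implements that formula for a single unnormalized primitive centred at the origin and checks it against central finite differences. The helper names are hypothetical, and this is not the code behind compute_AOs_grad.

```python
import numpy as np


def prim(r, alpha, n):
    x, y, z = r
    return x**n[0] * y**n[1] * z**n[2] * np.exp(-alpha * (x * x + y * y + z * z))


def prim_grad(r, alpha, n):
    """Analytic (d/dx, d/dy, d/dz) of an unnormalized Cartesian Gaussian at the origin."""
    r = np.asarray(r, dtype=np.float64)
    g = np.exp(-alpha * np.dot(r, r))
    grad = np.empty(3)
    for i in range(3):
        others = np.prod([r[j] ** n[j] for j in range(3) if j != i])
        # d/dx_i [x_i^n e^{-a x_i^2}] = (n x_i^{n-1} - 2 a x_i^{n+1}) e^{-a x_i^2}
        lower = n[i] * r[i] ** (n[i] - 1) if n[i] > 0 else 0.0
        grad[i] = (lower - 2.0 * alpha * r[i] ** (n[i] + 1)) * others * g
    return grad


r0 = np.array([0.3, -0.2, 0.5])
alpha, n, h = 0.7, (1, 0, 2), 1e-5
fd = np.array([(prim(r0 + h * e, alpha, n) - prim(r0 - h * e, alpha, n)) / (2 * h)
               for e in np.eye(3)])
print(np.allclose(fd, prim_grad(r0, alpha, n), atol=1e-8))  # True
```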

jqmc.atomic_orbital.compute_AOs_laplacian(aos_data, r_carts)

Return analytic Laplacians of contracted atomic orbitals.

Dispatches to Cartesian or real-spherical implementations; returns float64 JAX arrays (ensure jax_enable_x64=True).

Parameters:
  • aos_data (AOs_sphe_data | AOs_cart_data) – AOs_cart_data or AOs_sphe_data describing centers, primitive exponents/coefficients, angular data, and contraction mapping (run sanity_check() beforehand).

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) (Bohr). Casts to float64 internally via jnp.asarray.

Returns:

Laplacians of all contracted AOs, shape (num_ao, N_e).

Return type:

jax.Array

Raises:

NotImplementedError – If aos_data is not Cartesian or spherical.
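For the simplest case, an unnormalized s-type primitive, the Laplacian has the closed form \(\nabla^2 e^{-\alpha r^2} = (4\alpha^2 r^2 - 6\alpha)\, e^{-\alpha r^2}\). The sketch below checks that formula against a second-order finite-difference stencil. It illustrates the quantity compute_AOs_laplacian returns per AO and is not jqmc's implementation; the helper names are hypothetical.

```python
import numpy as np


def s_prim(r, alpha):
    return np.exp(-alpha * np.dot(r, r))


def s_prim_laplacian(r, alpha):
    r2 = np.dot(r, r)
    # closed form for an unnormalized s-type Gaussian in 3D
    return (4.0 * alpha**2 * r2 - 6.0 * alpha) * np.exp(-alpha * r2)


r0 = np.array([0.4, 0.1, -0.3])
alpha, h = 1.2, 1e-4
# sum of second central differences along x, y and z
fd = sum(
    (s_prim(r0 + h * e, alpha) - 2.0 * s_prim(r0, alpha) + s_prim(r0 - h * e, alpha)) / h**2
    for e in np.eye(3)
)
print(np.isclose(fd, s_prim_laplacian(r0, alpha), atol=1e-5))  # True
```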

jqmc.coulomb_potential module

Effective core potential module.

Module containing classes and methods related to effective core potentials (ECPs) and bare Coulomb potentials.

class jqmc.coulomb_potential.Coulomb_potential_data(structure_data=<factory>, ecp_flag=False, z_cores=<factory>, max_ang_mom_plus_1=<factory>, num_ecps=0, ang_moms=<factory>, nucleus_index=<factory>, exponents=<factory>, coefficients=<factory>, powers=<factory>)

Bases: object

Container for bare Coulomb and effective core potential (ECP) parameters.

Parameters:
  • structure_data (Structure_data) – Underlying nuclear geometry and metadata.

  • ecp_flag (bool) – Whether ECPs are present. When True, all ECP arrays must be populated.

  • z_cores (list[float] | tuple[float]) – Core electrons removed per atom; length natom.

  • max_ang_mom_plus_1 (list[int] | tuple[int]) – l_max + 1 for each atom; length natom.

  • num_ecps (int) – Total number of ECP projector terms across all atoms and angular momenta.

  • ang_moms (list[int] | tuple[int]) – Angular momentum l per ECP term; length num_ecps.

  • nucleus_index (list[int] | tuple[int]) – Atom index per ECP term; length num_ecps.

  • exponents (list[float] | tuple[float]) – Gaussian exponents per ECP term; length num_ecps.

  • coefficients (list[float] | tuple[float]) – Prefactors per ECP term; length num_ecps.

  • powers (list[int] | tuple[int]) – Polynomial powers per ECP term; length num_ecps.

Notes

  • When ecp_flag is False, all ECP-related sequences must be empty and num_ecps should be 0.

  • Arrays are stored as Python lists/tuples for pytrees; conversion to jax.Array happens in the compute kernels.

ang_moms: list[int] | tuple[int]

Angular momentum l per ECP term (len = num_ecps).

coefficients: list[float] | tuple[float]

Prefactors per ECP term (len = num_ecps).

ecp_flag: bool = False

Whether ECP parameters are active.

exponents: list[float] | tuple[float]

Gaussian exponents per ECP term (len = num_ecps).

max_ang_mom_plus_1: list[int] | tuple[int]

l_max + 1 per atom (len = natom).

nucleus_index: list[int] | tuple[int]

Atom index per ECP term (len = num_ecps).

num_ecps: int = 0

Total ECP projector terms across all atoms.

powers: list[int] | tuple[int]

Polynomial powers per ECP term (len = num_ecps).

replace(**updates)

Returns a new object replacing the specified fields with new values.

sanity_check()

Check the attributes of the class for consistency.

Verifies that the dimensions of the stored arguments are mutually consistent.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None
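The individual checks performed by sanity_check are not listed. The hypothetical sketch below encodes only the length rules stated in the parameter descriptions and Notes:

  • When ecp_flag is False, num_ecps is 0 and every per-term sequence is empty.
  • When ecp_flag is True, z_cores and max_ang_mom_plus_1 have length natom, and every per-term sequence has length num_ecps.

How jqmc treats z_cores when ecp_flag is False is not stated, so the sketch leaves it unchecked.

```python
def check_ecp_lengths(natom, ecp_flag, z_cores, max_ang_mom_plus_1, num_ecps,
                      ang_moms, nucleus_index, exponents, coefficients, powers):
    """Hypothetical consistency check derived from the parameter descriptions above."""
    per_term = [ang_moms, nucleus_index, exponents, coefficients, powers]
    if not ecp_flag:
        # Notes: without ECPs, num_ecps is 0 and all ECP sequences are empty
        if num_ecps != 0 or any(len(s) != 0 for s in per_term):
            raise ValueError("ecp_flag is False but ECP arrays are populated")
        return
    if len(z_cores) != natom or len(max_ang_mom_plus_1) != natom:
        raise ValueError("per-atom arrays must have length natom")
    if any(len(s) != num_ecps for s in per_term):
        raise ValueError("per-term arrays must have length num_ecps")


# One atom carrying two ECP terms (l = 0 and l = 1): passes silently
check_ecp_lengths(1, True, [2.0], [2], 2, [0, 1], [0, 0], [1.0, 2.0], [3.0, -1.0], [2, 2])
```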

structure_data: Structure_data

Nuclear geometry and atom metadata.

z_cores: list[float] | tuple[float]

Core electrons removed per atom (len = natom).

jqmc.coulomb_potential.compute_bare_coulomb_potential(coulomb_potential_data, r_up_carts, r_dn_carts)

Compute bare Coulomb interaction (ion–ion, electron–ion, electron–electron).

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs present).

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

Returns:

Total bare Coulomb energy.

Return type:

float
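In atomic units the bare Coulomb energy has three terms:

  • ion–ion: \(\sum_{I<J} Z_I Z_J / R_{IJ}\)
  • electron–ion: \(-\sum_{i,I} Z_I / |r_i - R_I|\)
  • electron–electron: \(\sum_{i<j} 1 / |r_i - r_j|\), summed over up- and down-spin electrons together.

The NumPy sketch below assumes open boundary conditions. It also assumes that the effective charges Z_eff are the nuclear charges minus z_cores when ECPs are present. It is a hypothetical illustration, not jqmc's implementation.

```python
import numpy as np


def bare_coulomb(R, Z_eff, r_up, r_dn):
    """Ion-ion + electron-ion + electron-electron energy in atomic units (open boundaries)."""
    R = np.asarray(R, dtype=np.float64)
    Z = np.asarray(Z_eff, dtype=np.float64)
    r = np.vstack([np.asarray(r_up, dtype=np.float64).reshape(-1, 3),
                   np.asarray(r_dn, dtype=np.float64).reshape(-1, 3)])  # all electrons

    def pair_sum(q, x):
        # sum_{i<j} q_i q_j / |x_i - x_j|
        dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
        iu = np.triu_indices(len(x), k=1)
        return np.sum(np.outer(q, q)[iu] / dist[iu])

    e_ion_ion = pair_sum(Z, R)
    e_el_el = pair_sum(np.ones(len(r)), r)
    d_el_ion = np.linalg.norm(r[:, None, :] - R[None, :, :], axis=-1)  # (N_e, N_ion)
    e_el_ion = -np.sum(Z[None, :] / d_el_ion)
    return e_ion_ion + e_el_ion + e_el_el


# H2 at 1.4 bohr with one up and one down electron on the x axis
E = bare_coulomb([[0.0, 0.0, -0.7], [0.0, 0.0, 0.7]], [1.0, 1.0],
                 [[1.0, 0.0, 0.0]], [[-1.0, 0.0, 0.0]])
```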

jqmc.coulomb_potential.compute_bare_coulomb_potential_el_el(r_up_carts, r_dn_carts)

Electron–electron Coulomb interaction energy.

Parameters:
  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

Returns:

Electron–electron Coulomb energy.

Return type:

float

jqmc.coulomb_potential.compute_bare_coulomb_potential_el_ion(coulomb_potential_data, r_up_carts, r_dn_carts)

Total electron–ion Coulomb interaction energy.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs present).

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

Returns:

Electron–ion Coulomb energy.

Return type:

float

jqmc.coulomb_potential.compute_bare_coulomb_potential_el_ion_element_wise(coulomb_potential_data, r_up_carts, r_dn_carts)

Element-wise electron–ion Coulomb interactions.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs present).

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

Returns:

Element-wise ion–electron interactions for up spins and down spins (shape (N_up,) and (N_dn,)).

Return type:

tuple[jax.Array, jax.Array]

jqmc.coulomb_potential.compute_bare_coulomb_potential_ion_ion(coulomb_potential_data)

Ion–ion Coulomb interaction energy.

Parameters:

coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs present).

Returns:

Ion–ion Coulomb energy.

Return type:

float

jqmc.coulomb_potential.compute_coulomb_potential(coulomb_potential_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6, wavefunction_data=None)

Compute total Coulomb energy including bare and ECP terms.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – Structure, charges, and ECP parameters.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)) for non-local ECP.

  • NN (int) – Number of nearest nuclei to include for each electron in the non-local term.

  • Nv (int) – Number of quadrature points on the sphere.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ECP ratios; required when ecp_flag is True.

Returns:

Sum of bare Coulomb (ion–ion, electron–ion, electron–electron) and ECP (local + non-local) energies.

Return type:

float

jqmc.coulomb_potential.compute_discretized_bare_coulomb_potential_el_ion_element_wise(coulomb_potential_data, r_up_carts, r_dn_carts, alat)

Element-wise electron–ion Coulomb interactions with distance floor alat.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – Structure and charges (effective if ECPs present).

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • alat (float) – Minimum allowed distance to avoid divergence.

Returns:

Element-wise ion–electron interactions for up spins and down spins (shape (N_up,) and (N_dn,)).

Return type:

tuple[jax.Array, jax.Array]
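The element-wise variants return one electron–ion energy per electron, so the two arrays sum to the total electron–ion energy. The sketch below reads "distance floor alat" as replacing every electron–nucleus distance d with max(d, alat). That keeps the potential finite at coalescence but is an assumption about the discretized routine. When alat is omitted, the sketch matches the plain element-wise function. The helper name is hypothetical.

```python
import numpy as np


def el_ion_element_wise(R, Z_eff, r_up, r_dn, alat=None):
    """Per-electron electron-ion energies as arrays of shape (N_up,) and (N_dn,).

    If alat is given, each electron-nucleus distance is floored at alat
    (assumed reading of "distance floor").
    """
    R = np.asarray(R, dtype=np.float64)
    Z = np.asarray(Z_eff, dtype=np.float64)

    def per_electron(r):
        r = np.asarray(r, dtype=np.float64).reshape(-1, 3)
        d = np.linalg.norm(r[:, None, :] - R[None, :, :], axis=-1)  # (N, N_ion)
        if alat is not None:
            d = np.maximum(d, alat)
        return -np.sum(Z[None, :] / d, axis=1)

    return per_electron(r_up), per_electron(r_dn)


# Up electron sitting exactly on a proton: finite only because of the floor
up, dn = el_ion_element_wise([[0.0, 0.0, 0.0]], [1.0],
                             [[0.0, 0.0, 0.0]], [[2.0, 0.0, 0.0]], alat=0.1)
```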

jqmc.coulomb_potential.compute_ecp_coulomb_potential(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6)

Compute total ECP energy (local + non-local) for a configuration.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).

  • NN (int) – Number of nearest nuclei to include for each electron in the non-local term.

  • Nv (int) – Number of quadrature points on the sphere.

Returns:

Sum of local and non-local ECP contributions for the given geometry.

Return type:

float
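The non-local ECP term is an angular integral evaluated on Nv quadrature points that are rotated by RT. A common six-point choice is the octahedral rule: ±x, ±y and ±z with equal weights 1/6. It integrates Legendre polynomials \(P_l\) exactly up to l = 3 for any rotation, so a random RT removes orientation bias without losing accuracy. Whether jqmc's Nv=6 grid is exactly this rule is an assumption. The sketch checks the property numerically.

```python
import numpy as np

# Six-point octahedral rule: +-x, +-y, +-z with equal weights (assumed Nv=6 grid)
base_points = np.vstack([np.eye(3), -np.eye(3)])
weights = np.full(6, 1.0 / 6.0)


def legendre(l, t):
    # P_0..P_3 written out explicitly
    return [np.ones_like(t), t, 0.5 * (3 * t**2 - 1), 0.5 * (5 * t**3 - 3 * t)][l]


rng = np.random.default_rng(0)
RT, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
grid_points = base_points @ RT.T                # rotated quadrature points, shape (Nv, 3)
u = rng.normal(size=3)
u /= np.linalg.norm(u)                          # arbitrary unit direction

# Spherical average of P_l(u . r_hat): exactly 1 for l = 0 and 0 for l = 1, 2, 3
averages = [float(np.sum(weights * legendre(l, grid_points @ u))) for l in range(4)]
print(averages)  # approximately [1, 0, 0, 0], up to rounding
```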

jqmc.coulomb_potential.compute_ecp_local_parts_all_pairs(coulomb_potential_data, r_up_carts, r_dn_carts)

Compute local ECP contribution over all nucleus–electron pairs.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

Returns:

Total local ECP energy for the provided electron coordinates.

Return type:

float

jqmc.coulomb_potential.compute_ecp_non_local_part_all_pairs_jax_weights_grid_points(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, weights, grid_points, flag_determinant_only=0)

Vectorized non-local ECP projection over all pairs with provided quadrature.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • weights (list[float]) – Quadrature weights for angular integration.

  • grid_points (npt.NDArray[np.float64]) – Quadrature grid points with shape (Nv, 3).

  • flag_determinant_only (int) – If 1, skip Jastrow in the wavefunction ratio; if 0, include it.

Returns:

  • Mesh-displaced up-spin coordinates per configuration.

  • Mesh-displaced down-spin coordinates per configuration.

  • Non-local contributions for up-spin mesh points.

  • Non-local contributions for down-spin mesh points.

  • Scalar sum of all non-local contributions.

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array, float]

jqmc.coulomb_potential.compute_ecp_non_local_parts_all_pairs(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, Nv=6, flag_determinant_only=False)

Compute non-local ECP contribution considering all nucleus–electron pairs.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).

  • Nv (int) – Number of quadrature points on the sphere.

  • flag_determinant_only (bool) – If True, ignore Jastrow in the wavefunction ratio.

Returns:

  • Mesh-displaced r_up_carts per configuration.

  • Mesh-displaced r_dn_carts per configuration.

  • Non-local ECP contributions per configuration (flattened).

  • Scalar sum of all non-local contributions.

Return type:

tuple[list[jax.Array], list[jax.Array], jax.Array, float]

jqmc.coulomb_potential.compute_ecp_non_local_parts_nearest_neighbors(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, NN=1, Nv=6, flag_determinant_only=False)

Compute non-local ECP contribution with a nearest-neighbor cutoff.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).

  • NN (int) – Number of nearest nuclei to include for each electron.

  • Nv (int) – Number of quadrature points on the sphere.

  • flag_determinant_only (bool) – If True, ignore Jastrow in the wavefunction ratio.

Returns:

  • Mesh-displaced r_up_carts per configuration.

  • Mesh-displaced r_dn_carts per configuration.

  • Non-local ECP contributions per configuration (flattened).

  • Scalar sum of all non-local contributions.

Return type:

tuple[list[jax.Array], list[jax.Array], jax.Array, float]
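Restricting the non-local sum to the NN closest nuclei of each electron requires selecting nuclei by distance. The sketch below uses a hypothetical helper, nearest_nuclei, that does this with a plain argsort. jqmc may select neighbours differently, for example with a JAX top-k primitive.

```python
import numpy as np


def nearest_nuclei(r_carts, R, NN=1):
    """Indices of the NN closest nuclei for each electron, shape (N_e, NN)."""
    r = np.asarray(r_carts, dtype=np.float64)
    R = np.asarray(R, dtype=np.float64)
    d = np.linalg.norm(r[:, None, :] - R[None, :, :], axis=-1)  # (N_e, N_nuc)
    return np.argsort(d, axis=1)[:, :NN]


R = [[0.0, 0.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 5.0]]
r = [[0.0, 0.0, 0.9], [0.0, 0.0, 4.0]]
print(nearest_nuclei(r, R, NN=2).tolist())  # [[0, 1], [2, 1]]
```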

jqmc.coulomb_potential.compute_ecp_non_local_parts_nearest_neighbors_fast_update(coulomb_potential_data, wavefunction_data, r_up_carts, r_dn_carts, RT, A_old_inv, NN=1, Nv=6, flag_determinant_only=False)

Fast-update variant of non-local ECP contributions (nearest neighbors).

This variant reuses the inverse geminal matrix (A_old_inv) to compute determinant ratios and combines them with Jastrow ratios, avoiding a full recomputation of the wavefunction at each mesh point.

Parameters:
  • coulomb_potential_data (Coulomb_potential_data) – ECP parameters and structure data.

  • wavefunction_data (Wavefunction_data) – Wavefunction (geminal + Jastrow) used for ratios.

  • r_up_carts (jax.Array) – Up-spin electron Cartesian coordinates with shape (N_up, 3) and float64 dtype.

  • r_dn_carts (jax.Array) – Down-spin electron Cartesian coordinates with shape (N_dn, 3) and float64 dtype.

  • RT (jax.Array) – Rotation matrix applied to quadrature grid points (shape (3, 3)).

  • A_old_inv (jax.Array) – Inverse geminal matrix evaluated at (r_up_carts, r_dn_carts).

  • NN (int) – Number of nearest nuclei to include for each electron.

  • Nv (int) – Number of quadrature points on the sphere.

  • flag_determinant_only (bool) – If True, ignore Jastrow in the wavefunction ratio.

Returns:

  • Mesh-displaced r_up_carts per configuration.

  • Mesh-displaced r_dn_carts per configuration.

  • Non-local ECP contributions per configuration (flattened).

  • Scalar sum of all non-local contributions.

Return type:

tuple[list[jax.Array], list[jax.Array], jax.Array, float]
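Why a stored inverse helps: if moving one electron to a mesh point changes a single row of the geminal matrix, the matrix determinant lemma gives det(A_new)/det(A) = new_row · A⁻¹[:, i]. That is an O(N) dot product instead of an O(N³) determinant. For down-spin electrons, whose moves change a column under the usual convention, the analogous expression uses a row of A⁻¹. The sketch verifies the identity on a random stand-in matrix. jqmc's exact row/column convention for A_old_inv is an assumption.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.normal(size=(n, n)) + n * np.eye(n)  # well-conditioned stand-in for the geminal matrix
A_old_inv = np.linalg.inv(A)

i = 2                                         # electron (row) being moved
new_row = rng.normal(size=n)                  # matrix row evaluated at the mesh point

# Replacing row i: det(A_new) / det(A) = new_row @ A_old_inv[:, i]
ratio_fast = new_row @ A_old_inv[:, i]

A_new = A.copy()
A_new[i, :] = new_row
ratio_full = np.linalg.det(A_new) / np.linalg.det(A)
print(np.isclose(ratio_fast, ratio_full))  # True
```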

jqmc.determinant module

Determinant module.

class jqmc.determinant.Geminal_data(num_electron_up=0, num_electron_dn=0, orb_data_up_spin=<factory>, orb_data_dn_spin=<factory>, lambda_matrix=<factory>)

Bases: object

Geminal (AGP) parameters and orbital references.

Parameters:
  • num_electron_up (int) – Number of spin-up electrons.

  • num_electron_dn (int) – Number of spin-down electrons.

  • orb_data_up_spin (AOs_sphe_data | AOs_cart_data | MOs_data) – Basis/orbitals for spin-up electrons.

  • orb_data_dn_spin (AOs_sphe_data | AOs_cart_data | MOs_data) – Basis/orbitals for spin-down electrons.

  • lambda_matrix (npt.NDArray | jax.Array) – Geminal pairing matrix with shape (orb_num_up, orb_num_dn + num_electron_up - num_electron_dn).

Notes

  • For closed shells, orb_num_up == orb_num_dn and lambda_matrix is square.

  • For open shells, the right block encodes unpaired spin-up orbitals.
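The lambda_matrix shape leads to a square N_up × N_up geminal matrix. The paired (left) block couples up- and down-spin orbitals, and the unpaired (right) block adds one column per excess up-spin electron. The sketch below makes the following assumptions:

  • The paired block occupies the first orb_num_dn columns, as the Notes imply.
  • Orbital values are laid out as (orbitals, electrons), matching the compute_AOs output.
  • The helper geminal_matrix is hypothetical.

The determinant of this matrix is the usual AGP amplitude.

```python
import numpy as np


def geminal_matrix(phi_up, phi_dn, lambda_matrix):
    """Build the square (N_up, N_up) geminal matrix.

    phi_up: orbital values at up electrons, shape (orb_num_up, N_up)
    phi_dn: orbital values at down electrons, shape (orb_num_dn, N_dn)
    lambda_matrix: shape (orb_num_up, orb_num_dn + N_up - N_dn)
    """
    orb_num_dn = phi_dn.shape[0]
    lam_paired = lambda_matrix[:, :orb_num_dn]    # left block: up-down pairing
    lam_unpaired = lambda_matrix[:, orb_num_dn:]  # right block: unpaired up orbitals
    paired = phi_up.T @ lam_paired @ phi_dn       # (N_up, N_dn)
    unpaired = phi_up.T @ lam_unpaired            # (N_up, N_up - N_dn)
    return np.hstack([paired, unpaired])


# Open shell: 5 orbitals per spin, N_up = 3, N_dn = 2, so lambda is (5, 5 + 1)
rng = np.random.default_rng(2)
G = geminal_matrix(rng.normal(size=(5, 3)), rng.normal(size=(5, 2)), rng.normal(size=(5, 6)))
print(G.shape)  # (3, 3)
```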

accumulate_position_grad(grad_geminal)

Aggregate position gradients from geminal-related structures.

Parameters:

grad_geminal (Geminal_data)

apply_block_update(block)

Apply a single variational-parameter block update to this Geminal object.

This method is the Geminal-specific counterpart of Wavefunction_data.apply_block_updates(). It receives a generic VariationalParameterBlock whose values have already been updated (typically by block.apply_update inside the SR/MCMC driver), and interprets that block according to the structure of the geminal (lambda) matrix.

Responsibilities of this method are:

  • Map the block name (currently "lambda_matrix") to the internal geminal parameters.

  • Handle the splitting of a rectangular lambda matrix into paired and unpaired parts when needed.

  • Enforce Geminal-specific structural constraints, especially the symmetry conditions on the paired block of the lambda matrix.

All details about how the lambda parameters are stored and constrained live here (and in the surrounding Geminal_data class), not in VariationalParameterBlock or in the optimizer. This keeps the SR/MCMC machinery and the block abstraction structure-agnostic: adding new Geminal parameters should only require updating the block construction in Wavefunction_data.get_variational_blocks and adding the corresponding handling in this method.

Parameters:

block (VariationalParameterBlock)

Return type:

Geminal_data

collect_param_grads(grad_geminal)

Collect parameter gradients into a flat dict keyed by block name.

Parameters:

grad_geminal (Geminal_data)

Return type:

dict[str, object]

property compute_orb_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]

Function for computing AOs or MOs.

The API function that computes AOs or MOs, chosen according to the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.

Returns:

The API function that computes AOs or MOs.

Return type:

Callable

Raises:

NotImplementedError – If the instances of orb_data_up_spin/orb_data_dn_spin are neither AOs_data/AOs_data nor MOs_data/MOs_data.

property compute_orb_grad_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]#

Function for computing AO or MO gradients.

The API method to compute AO or MO gradients corresponding to the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.

Returns:

The API method to compute AO or MO gradients.

Return type:

Callable

Raises:

NotImplementedError – If orb_data_up_spin and orb_data_dn_spin are not both AOs_data or both MOs_data.

property compute_orb_laplacian_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]#

Function for computing AO or MO Laplacians.

The API method to compute AO or MO Laplacians corresponding to the instances stored in self.orb_data_up_spin and self.orb_data_dn_spin.

Returns:

The API method to compute AO or MO Laplacians.

Return type:

Callable

Raises:

NotImplementedError – If orb_data_up_spin and orb_data_dn_spin are not both AOs_data or both MOs_data.

classmethod convert_from_MOs_to_AOs(geminal_data)#

Convert MOs to AOs.

Parameters:

geminal_data (Geminal_data)

Return type:

Geminal_data

lambda_matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array#

Geminal pairing matrix; see class notes for expected shape.

num_electron_dn: int = 0#

Number of spin-down electrons.

num_electron_up: int = 0#

Number of spin-up electrons.

orb_data_dn_spin: AOs_sphe_data | AOs_cart_data | MOs_data#

Orbital data (AOs or MOs) for spin-down electrons.

orb_data_up_spin: AOs_sphe_data | AOs_cart_data | MOs_data#

Orbital data (AOs or MOs) for spin-up electrons.

property orb_num_dn: int#

Number of orbitals for spin-down electrons.

The number of atomic orbitals or molecular orbitals for spin-down electrons, depending on the instance stored in the attribute orb_data_dn_spin.

Returns:

The number of atomic orbitals or molecular orbitals for down electrons.

Return type:

int

Raises:

NotImplementedError – If the instance of orb_data_dn_spin is neither AOs_data nor MOs_data.

property orb_num_up: int#

Number of orbitals for spin-up electrons.

The number of atomic orbitals or molecular orbitals for spin-up electrons, depending on the instance stored in the attribute orb_data_up_spin.

Returns:

The number of atomic orbitals or molecular orbitals for up electrons.

Return type:

int

Raises:

NotImplementedError – If the instance of orb_data_up_spin is neither AOs_data nor MOs_data.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

jqmc.determinant.compute_AS_regularization_factor(geminal_data, r_up_carts, r_dn_carts)#

Compute Attaccalite–Sorella regularization from electron coordinates.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).

Returns:

Scalar AS regularization factor.

Return type:

jax.Array

jqmc.determinant.compute_AS_regularization_factor_fast_update(geminal, geminal_inv)#

Compute Attaccalite–Sorella regularization via fast update.

Parameters:
  • geminal (ndarray[tuple[Any, ...], dtype[float64]]) – Geminal matrix with shape (N_up, N_up).

  • geminal_inv (ndarray[tuple[Any, ...], dtype[float64]]) – Inverse geminal matrix with shape (N_up, N_up).

Returns:

Scalar AS regularization factor.

Return type:

jax.Array

jqmc.determinant.compute_det_geminal_all_elements(geminal_data, r_up_carts, r_dn_carts)#

Compute \(\det G\) for the geminal matrix.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).

Returns:

Scalar determinant of the geminal matrix.

Return type:

float

jqmc.determinant.compute_geminal_all_elements(geminal_data, r_up_carts, r_dn_carts)#

Compute the geminal matrix \(G\) for all electron pairs.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).

Returns:

Geminal matrix with shape (N_up, N_up) combining paired and unpaired blocks.

Return type:

jax.Array
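The following NumPy sketch shows one common way to assemble the (N_up, N_up) matrix from orbital values. This construction is standard for AGP-type geminals with unpaired orbitals, but it is assumed here rather than taken from the jqmc source. The paired block is computed as Phi_up^T lambda_paired Phi_dn, and the N_unpaired = N_up - N_dn extra columns as Phi_up^T lambda_unpaired. The orbital matrices are random stand-ins for evaluated AOs/MOs, and build_geminal_matrix is a hypothetical name.

```python
import numpy as np


def build_geminal_matrix(phi_up, phi_dn, lambda_matrix):
    """Hypothetical sketch of a geminal matrix with unpaired columns.

    phi_up: (n_orb_up, N_up) orbital values at the spin-up electrons.
    phi_dn: (n_orb_dn, N_dn) orbital values at the spin-down electrons.
    lambda_matrix: (n_orb_up, n_orb_dn + N_up - N_dn); the trailing columns are
        assumed to hold the unpaired-orbital coefficients.
    """
    n_orb_dn = phi_dn.shape[0]
    lam_paired = lambda_matrix[:, :n_orb_dn]
    lam_unpaired = lambda_matrix[:, n_orb_dn:]
    g_paired = phi_up.T @ lam_paired @ phi_dn     # (N_up, N_dn)
    g_unpaired = phi_up.T @ lam_unpaired          # (N_up, N_up - N_dn)
    return np.hstack([g_paired, g_unpaired])      # (N_up, N_up)


rng = np.random.default_rng(0)
n_orb, n_up, n_dn = 5, 3, 2
phi_up = rng.normal(size=(n_orb, n_up))
phi_dn = rng.normal(size=(n_orb, n_dn))
lam = rng.normal(size=(n_orb, n_orb + n_up - n_dn))
G = build_geminal_matrix(phi_up, phi_dn, lam)
# slogdet avoids overflow/underflow of det G for many electrons: det G = sign * exp(logdet).
sign, logdet = np.linalg.slogdet(G)
print(G.shape)  # (3, 3)
```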

jqmc.determinant.compute_geminal_dn_one_column_elements(geminal_data, r_up_carts, r_dn_cart)#

Single column of the geminal matrix for one spin-down electron.

Parameters:
  • geminal_data – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_cart (Array) – Cartesian coordinate for one spin-down electron with shape (3,) or (1, 3).

Returns:

Column vector for the paired block with shape (N_up,).

Return type:

jax.Array

jqmc.determinant.compute_geminal_up_one_row_elements(geminal_data, r_up_cart, r_dn_carts)#

Single row of the geminal matrix for one spin-up electron.

Parameters:
  • geminal_data – Geminal parameters and orbital references.

  • r_up_cart (Array) – Cartesian coordinate for one spin-up electron with shape (3,) or (1, 3).

  • r_dn_carts (Array) – Cartesian coordinates for all spin-down electrons with shape (N_dn, 3).

Returns:

Row vector with shape (N_dn + N_unpaired,).

Return type:

jax.Array

jqmc.determinant.compute_grads_and_laplacian_ln_Det(geminal_data, r_up_carts, r_dn_carts)#

Gradients and Laplacians of \(\ln \det G\) for each electron.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).

Returns:

  • Gradients for spin-up electrons with shape (N_up, 3).

  • Gradients for spin-down electrons with shape (N_dn, 3).

  • Laplacians for spin-up electrons with shape (N_up,).

  • Laplacians for spin-down electrons with shape (N_dn,).

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array]

jqmc.determinant.compute_grads_and_laplacian_ln_Det_fast(geminal_data, r_up_carts, r_dn_carts, geminal_inverse)#

Gradients and Laplacians of \(\ln \det G\) using a precomputed inverse.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • r_up_carts (Array) – Cartesian coordinates of spin-up electrons with shape (N_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of spin-down electrons with shape (N_dn, 3).

  • geminal_inverse (Array) – Precomputed inverse of the geminal matrix G.

Returns:

Gradients (up/down) and Laplacians (up/down) of \(\ln \det G\) per electron.

Return type:

tuple[Array, Array, Array, Array]
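Only row \(i\) of \(G\) depends on the position of spin-up electron \(i\); for spin-down electrons the same holds for columns. The determinant is linear in that row, which gives \(\nabla_i \ln \det G = \sum_j (\nabla_i G_{ij}) (G^{-1})_{ji}\) and \(\nabla_i^2 \ln \det G = \sum_j (\nabla_i^2 G_{ij}) (G^{-1})_{ji} - |\nabla_i \ln \det G|^2\). The NumPy sketch below applies these identities with a precomputed inverse to a toy matrix \(G_{ij} = \exp(-|\mathbf{r}_i - \mathbf{s}_j|^2)\), where the fixed points \(\mathbf{s}_j\) and all function names are illustrative stand-ins for the real geminal.

```python
import numpy as np

# Fixed, well-separated "partner" points s_j; toy matrix G_ij = exp(-|r_i - s_j|^2).
S = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [0.0, 1.5, 0.0]])


def toy_row(r):
    """Row i of the toy G for one electron position r, shape (N,)."""
    d = r[None, :] - S
    return np.exp(-np.sum(d**2, axis=1))


def toy_row_grad_lap(r):
    """Analytic gradient (N, 3) and Laplacian (N,) of the toy row with respect to r."""
    d = r[None, :] - S
    f = np.exp(-np.sum(d**2, axis=1))
    grad = -2.0 * d * f[:, None]
    lap = (4.0 * np.sum(d**2, axis=1) - 6.0) * f  # Laplacian of exp(-|d|^2) in 3D
    return grad, lap


def grads_and_laplacian_ln_det(r_all, g_inv):
    """Per-electron gradient (N, 3) and Laplacian (N,) of ln|det G| from a precomputed inverse."""
    n = r_all.shape[0]
    grads = np.zeros((n, 3))
    laps = np.zeros(n)
    for i in range(n):
        dg, lg = toy_row_grad_lap(r_all[i])
        col = g_inv[:, i]                         # entries (G^{-1})_{ji}
        grads[i] = dg.T @ col
        laps[i] = lg @ col - grads[i] @ grads[i]
    return grads, laps


def ln_abs_det(r_all):
    """ln|det G| evaluated from scratch (used only to cross-check)."""
    return np.linalg.slogdet(np.stack([toy_row(r) for r in r_all]))[1]


rng = np.random.default_rng(1)
r_all = S + 0.1 * rng.normal(size=S.shape)       # electrons near the partner points
g = np.stack([toy_row(r) for r in r_all])
grads, laps = grads_and_laplacian_ln_det(r_all, np.linalg.inv(g))
```

With the inverse already available, each electron costs \(O(N)\) per orbital-derivative row instead of a fresh \(O(N^3)\) factorization.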

jqmc.determinant.compute_ratio_determinant_part(geminal_data, A_old_inv, old_r_up_carts, old_r_dn_carts, new_r_up_carts_arr, new_r_dn_carts_arr)#

Determinant ratio \(\det G(\mathbf{r}')/\det G(\mathbf{r})\) for batched moves.

Parameters:
  • geminal_data (Geminal_data) – Geminal parameters and orbital references.

  • A_old_inv (Array) – Inverse geminal matrix for the reference configuration with shape (N_up, N_up).

  • old_r_up_carts (Array) – Original spin-up electron coordinates with shape (N_up, 3).

  • old_r_dn_carts (Array) – Original spin-down electron coordinates with shape (N_dn, 3).

  • new_r_up_carts_arr (Array) – Proposed spin-up coordinates per grid with shape (N_grid, N_up, 3).

  • new_r_dn_carts_arr (Array) – Proposed spin-down coordinates per grid with shape (N_grid, N_dn, 3).

Returns:

Determinant ratios per grid with shape (N_grid,).

Return type:

jax.Array
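When a proposed move changes a single spin-up electron \(i\), only row \(i\) of \(G\) changes. The matrix determinant lemma then reduces the ratio to a dot product with a column of the old inverse: \(\det G'/\det G = \sum_j G'_{ij} (G^{-1})_{ji}\). For a spin-down electron \(j\), only column \(j\) changes. The NumPy sketch below illustrates this \(O(N)\) update for one move. How compute_ratio_determinant_part batches such moves over N_grid configurations is not shown, and the helper names are hypothetical.

```python
import numpy as np


def ratio_row_update(g_inv, i, new_row):
    """det(G') / det(G) when only row i of G is replaced by new_row."""
    return new_row @ g_inv[:, i]


def ratio_col_update(g_inv, j, new_col):
    """det(G') / det(G) when only column j of G is replaced by new_col."""
    return g_inv[j, :] @ new_col


rng = np.random.default_rng(2)
n = 4
g = rng.normal(size=(n, n)) + n * np.eye(n)  # well-conditioned test matrix
g_inv = np.linalg.inv(g)

new_row = rng.normal(size=n)
g_row = g.copy()
g_row[1, :] = new_row
ratio = ratio_row_update(g_inv, 1, new_row)
print(np.isclose(ratio, np.linalg.det(g_row) / np.linalg.det(g)))  # True
```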

jqmc.diff_mask module#

Utilities for masking PyTree leaves by derivative type.

This module centralizes the logic for selectively stopping gradients on parameter- or coordinate-like leaves so we can reuse the same dataclasses for all derivative modes (params, positions, or none) without proliferating specialized subclasses.

class jqmc.diff_mask.DiffMask(params=True, coords=True)#

Bases: object

Simple mask controlling which leaf types remain differentiable.

Parameters:
  • params (bool)

  • coords (bool)

coords: bool = True#
params: bool = True#
update(*, params=None, coords=None)#

Return a new mask with any provided overrides applied.

Parameters:
  • params (bool | None)

  • coords (bool | None)

Return type:

DiffMask
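The semantics of update(), where None means "keep the current value", can be mirrored with a frozen dataclass and dataclasses.replace. DiffMaskSketch below is a hypothetical stand-in, not jqmc's class. A mask with coords=False keeps only the parameter leaves differentiable.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DiffMaskSketch:
    """Illustrative stand-in for DiffMask (not the jqmc implementation)."""

    params: bool = True
    coords: bool = True

    def update(self, *, params=None, coords=None):
        # Only fields that were explicitly provided (not None) are overridden.
        overrides = {}
        if params is not None:
            overrides["params"] = params
        if coords is not None:
            overrides["coords"] = coords
        return replace(self, **overrides)


mask = DiffMaskSketch()
params_only = mask.update(coords=False)
print(params_only)  # DiffMaskSketch(params=True, coords=False)
```

Because the dataclass is frozen, update() returns a new object and leaves the original mask untouched.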

jqmc.diff_mask.apply_diff_mask(obj, mask)#

Return a copy of obj with gradients stopped according to mask.

The function recurses through dataclass fields (including Flax struct dataclasses), honoring optional diff_tag metadata on fields when present. Otherwise, a small set of field-name heuristics is used to decide whether a leaf should be treated as a parameter (“param”) or coordinate (“coord”). Lists/tuples are traversed elementwise; everything else is returned unchanged.

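Parameters:
  • obj – Object to process: a dataclass (including Flax struct dataclasses), a list/tuple, or any other value.

  • mask – DiffMask selecting which leaf types remain differentiable.
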
Return type:

Any
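The traversal described above can be sketched in pure Python. The field-name heuristics (PARAM_NAMES, COORD_NAMES), the params/coords/stop arguments, and the toy dataclasses are all hypothetical; jqmc's actual rules are not listed on this page. In a JAX setting, stop would be jax.lax.stop_gradient. Here it is replaced by a marker function so the example runs without JAX.

```python
from dataclasses import dataclass, fields, is_dataclass, replace
from typing import Any, Callable

# Hypothetical field-name heuristics (not jqmc's documented rules).
PARAM_NAMES = {"lambda_matrix", "j_matrix", "jastrow_1b_param", "jastrow_2b_param"}
COORD_NAMES = {"positions"}


def apply_mask_sketch(obj: Any, params: bool, coords: bool, stop: Callable[[Any], Any]) -> Any:
    """Return a copy of obj in which masked-out leaves are passed through stop()."""
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_mask_sketch(x, params, coords, stop) for x in obj)
    if not is_dataclass(obj) or isinstance(obj, type):
        return obj  # plain leaves are returned unchanged
    updates = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        # Prefer explicit diff_tag metadata; otherwise fall back to field-name heuristics.
        tag = f.metadata.get("diff_tag")
        if tag is None:
            if f.name in PARAM_NAMES:
                tag = "param"
            elif f.name in COORD_NAMES:
                tag = "coord"
        if (tag == "param" and not params) or (tag == "coord" and not coords):
            updates[f.name] = stop(value)
        elif tag is None:
            updates[f.name] = apply_mask_sketch(value, params, coords, stop)
    return replace(obj, **updates)


@dataclass(frozen=True)
class ToyStructure:
    positions: tuple


@dataclass(frozen=True)
class ToyGeminal:
    structure_data: ToyStructure
    lambda_matrix: tuple


def mark(x):
    # Stand-in for jax.lax.stop_gradient so the example runs without JAX.
    return ("stopped", x)


gem = ToyGeminal(ToyStructure((0.0, 0.0, 1.0)), (1.0, 2.0))
coords_frozen = apply_mask_sketch(gem, params=True, coords=False, stop=mark)
params_frozen = apply_mask_sketch(gem, params=False, coords=True, stop=mark)
```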

jqmc.function_collections module#

Collections of useful functions.

jqmc.hamiltonians module#

Hamiltonian module.

class jqmc.hamiltonians.Hamiltonian_data(structure_data=<factory>, coulomb_potential_data=<factory>, wavefunction_data=<factory>)#

Bases: object

Hamiltonian dataclass.

The class contains data for computing the kinetic and potential energy terms.

Parameters:
  • structure_data (Structure_data) – an instance of Structure_data

  • coulomb_potential_data (Coulomb_potential_data) – an instance of Coulomb_potential_data

  • wavefunction_data (Wavefunction_data) – an instance of Wavefunction_data

Notes

Here are the differentiable arguments, i.e., those with pytree_node = True. This design is somewhat at odds with the object-oriented principle ‘Tell, Don’t Ask’, because the Hamiltonian_data knows too much about the details of the other classes. However, there is no other way to dynamically switch pytree_nodes on and off depending on the variational parameters chosen by the user, because @dataclass is statically generated.

WF parameters related:
  • lambda in wavefunction_data.geminal_data (determinant.py)

  • jastrow_2b_param in wavefunction_data.jastrow_data.jastrow_two_body_data (jastrow_factor.py)

  • j_matrix in wavefunction_data.jastrow_data.jastrow_three_body_data (jastrow_factor.py)

Atomic positions related:
  • positions in hamiltonian_data.structure_data (this file)

  • positions in wavefunction_data.geminal_data.mos_data/aos_data.structure_data (molecular_orbital.py/atomic_orbital.py)

  • positions in wavefunction_data.jastrow_data.jastrow_three_body_data.mos_data/aos_data.structure_data (jastrow_factor.py)

  • positions in Coulomb_potential_data.structure_data (coulomb_potential.py)

accumulate_position_grad(grad_hamiltonian)#

Aggregate position gradients from Hamiltonian components (structure + wavefunction).

Parameters:

grad_hamiltonian (Hamiltonian_data)

coulomb_potential_data: Coulomb_potential_data#
static load_from_hdf5(filepath='jqmc.h5')#

Load Hamiltonian data from an HDF5 file.

Parameters:

filepath (str, optional) – Path to the HDF5 file (default: 'jqmc.h5').

Returns:

An instance of Hamiltonian_data.

Return type:

Hamiltonian_data

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

save_to_hdf5(filepath='jqmc.h5')#

Save Hamiltonian data to an HDF5 file.

Parameters:

filepath (str, optional) – Path to the HDF5 file (default: 'jqmc.h5').

Return type:

None

structure_data: Structure_data#
wavefunction_data: Wavefunction_data#
jqmc.hamiltonians.compute_local_energy(hamiltonian_data, r_up_carts, r_dn_carts, RT)#

Compute the local energy.

The method computes the local energy at the electron configuration (r_up_carts, r_dn_carts).

Parameters:
  • hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data

  • r_up_carts (jnpt.ArrayLike) – Cartesian coordinates of up-spin electrons (dim: N_e^{up}, 3)

  • r_dn_carts (jnpt.ArrayLike) – Cartesian coordinates of dn-spin electrons (dim: N_e^{dn}, 3)

  • RT (jnpt.ArrayLike) – Transposed rotation matrix (i.e., R.T) used for the non-local part. It has no effect in all-electron calculations.

Returns:

The value of local energy (e_L) with the given wavefunction (float)

Return type:

float
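For \(\Psi = e^{J} \det G\), the kinetic part of the local energy can be written in terms of the per-electron gradients and Laplacians of \(\ln \Psi = J + \ln \det G\), using \(\nabla^2 \Psi / \Psi = \nabla^2 \ln \Psi + |\nabla \ln \Psi|^2\). This is a standard QMC identity; this page does not say whether compute_local_energy is implemented exactly this way. The sketch below applies it to the hydrogen ground state \(\Psi = e^{-r}\), whose exact local energy is \(-0.5\) Hartree at every point. The helper name kinetic_local_energy is hypothetical.

```python
import numpy as np


def kinetic_local_energy(grad_ln_psi, lap_ln_psi):
    """-1/2 * sum_i (lap_i ln Psi + |grad_i ln Psi|^2) from per-electron arrays.

    grad_ln_psi: (N, 3) gradients of ln Psi; lap_ln_psi: (N,) Laplacians of ln Psi.
    """
    return -0.5 * np.sum(lap_ln_psi + np.sum(grad_ln_psi**2, axis=1))


# Hydrogen atom, Psi = exp(-r): grad ln Psi = -r_hat, lap ln Psi = -2/r.
r_vec = np.array([[0.3, -0.4, 1.2]])
r = np.linalg.norm(r_vec, axis=1)
grad = -r_vec / r[:, None]
lap = -2.0 / r
e_kin = kinetic_local_energy(grad, lap)
e_pot = -np.sum(1.0 / r)  # electron-nucleus Coulomb term for Z = 1 (atomic units)
e_local = e_kin + e_pot
print(e_local)  # approximately -0.5
```

Working with \(\ln \Psi\) avoids evaluating \(\Psi\) itself, which can overflow or underflow for large systems.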

jqmc.jastrow_factor module#

Jastrow module.

class jqmc.jastrow_factor.Jastrow_NN_data(nn_def=None, params=None, flat_shape=(), num_params=0, treedef=None, shapes=<factory>, hidden_dim=64, num_layers=3, num_rbf=16, cutoff=5.0, num_species=0, species_lookup=(0, ), species_values=<factory>, structure_data=None)#

Bases: object

Container for NN-based Jastrow factor.

This dataclass stores both the neural network definition and its parameters, together with helper functions that integrate the NN Jastrow term into the variational-parameter block machinery.

The intended usage is:

  • nn_def holds a Flax/SchNet-like module (e.g. NNJastrow).

  • params holds the corresponding PyTree of parameters.

  • flatten_fn / unflatten_fn convert between the PyTree and a 1D parameter vector for SR/MCMC.

  • If this dataclass is set to None inside Jastrow_data, the NN contribution is simply turned off. If it is not None, its contribution is evaluated and added on top of the analytic three-body Jastrow (if present).

Parameters:
  • nn_def (Any)

  • params (Any)

  • flat_shape (tuple[int, ...])

  • num_params (int)

  • treedef (Any)

  • shapes (list[tuple[int, ...]])

  • hidden_dim (int)

  • num_layers (int)

  • num_rbf (int)

  • cutoff (float)

  • num_species (int)

  • species_lookup (tuple[int, ...])

  • species_values (tuple[int, ...])

  • structure_data (Structure_data | None)

cutoff: float = 5.0#

Radial cutoff for features.

flat_shape: tuple[int, ...] = ()#

Shape of flattened params.

property flatten_fn: Callable[[Any], Array]#

Return a flatten function built from treedef.

This is constructed on each access and is not part of the serialized state (so it will not cause pickle errors).

hidden_dim: int = 64#

Hidden width used in NNJastrow.

classmethod init_from_structure(structure_data, hidden_dim=64, num_layers=3, num_rbf=16, cutoff=5.0, key=None)#

Initialize NN Jastrow from structure information.

This creates a PauliNet-style NNJastrow module, initializes its parameters with a dummy electron configuration, and prepares flatten/unflatten utilities for SR/MCMC.

Parameters:
  • structure_data (Structure_data)

  • hidden_dim (int)

  • num_layers (int)

  • num_rbf (int)

  • cutoff (float)

Return type:

Jastrow_NN_data

nn_def: Any = None#

Flax module definition (e.g., NNJastrow).

num_layers: int = 3#

Number of PauliNet blocks.

num_params: int = 0#

Total number of parameters.

num_rbf: int = 16#

PhysNet radial basis size.

num_species: int = 0#

Count of unique nuclear species.

params: Any = None#

Parameter PyTree for nn_def.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

shapes: list[tuple[int, ...]]#

Per-leaf shapes.

species_lookup: tuple[int, ...] = (0,)#

Lookup table mapping Z to species ids.

species_values: tuple[int, ...]#

Sorted unique atomic numbers used.
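species_values and species_lookup can be read as a sorted tuple of the atomic numbers present, plus an index table from Z to a dense species id. The NumPy sketch below builds both from a list of atomic numbers. Sizing the table as max(Z) + 1, filling unused entries with 0, and the name build_species_tables are assumptions for illustration.

```python
import numpy as np


def build_species_tables(atomic_numbers):
    """Return (species_values, species_lookup) for the given atomic numbers."""
    species_values = tuple(sorted(set(int(z) for z in atomic_numbers)))
    lookup = np.zeros(max(species_values) + 1, dtype=np.int32)
    for idx, z in enumerate(species_values):
        lookup[z] = idx  # Z -> dense species id
    return species_values, tuple(int(v) for v in lookup)


# Water: O, H, H
values, lookup = build_species_tables([8, 1, 1])
print(values)                          # (1, 8)
print([lookup[z] for z in (8, 1, 1)])  # [1, 0, 0]
```

A dense id lets the network index a small embedding table of size num_species instead of one indexed by Z.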

structure_data: Structure_data | None = None#

Structure info required for NN evaluation.

treedef: Any = None#

PyTree treedef for params.

property unflatten_fn: Callable[[Array], Any]#

Return an unflatten function built from treedef and shapes.

As with flatten_fn, this is constructed on each access and is not stored in the pickled state.
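The flatten/unflatten round trip provided by flatten_fn and unflatten_fn can be illustrated without Flax by treating a dict of arrays as the PyTree. In this sketch, the sorted key order plays the role of treedef, and the per-leaf shapes are recorded as in shapes. It is a hypothetical illustration of the idea, not the jqmc implementation. The parameter names in the example dict are made up.

```python
import numpy as np


def make_flatten_fns(params):
    """Build flatten/unflatten closures for a dict-of-arrays parameter tree."""
    keys = sorted(params)                          # stands in for treedef
    shapes = [np.shape(params[k]) for k in keys]   # per-leaf shapes
    sizes = [int(np.prod(s)) for s in shapes]

    def flatten(tree):
        return np.concatenate([np.ravel(tree[k]) for k in keys])

    def unflatten(vec):
        out, start = {}, 0
        for k, shape, size in zip(keys, shapes, sizes):
            out[k] = np.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flatten, unflatten, sum(sizes)


params = {"dense/kernel": np.ones((2, 3)), "dense/bias": np.zeros(3), "eta/scale": np.array(0.5)}
flatten, unflatten, num_params = make_flatten_fns(params)
vec = flatten(params)        # 1D vector, e.g. for an SR/MCMC update
restored = unflatten(vec)
print(num_params, vec.shape)  # 10 (10,)
```

Rebuilding the closures from stored metadata, rather than storing the closures themselves, is what keeps the dataclass picklable.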

class jqmc.jastrow_factor.Jastrow_data(jastrow_one_body_data=None, jastrow_two_body_data=None, jastrow_three_body_data=None, jastrow_nn_data=None)#

Bases: object

Jastrow dataclass.

The class contains data for evaluating a Jastrow function.

Parameters:
  • jastrow_one_body_data (Jastrow_one_body_data) – An instance of Jastrow_one_body_data. If None, the one-body Jastrow is turned off.

  • jastrow_two_body_data (Jastrow_two_body_data) – An instance of Jastrow_two_body_data. If None, the two-body Jastrow is turned off.

  • jastrow_three_body_data (Jastrow_three_body_data) – An instance of Jastrow_three_body_data. If None, the three-body Jastrow is turned off.

  • jastrow_nn_data (Jastrow_NN_data | None) – Optional container for an NN-based three-body Jastrow term. If None, the Jastrow NN contribution is turned off.

accumulate_position_grad(grad_jastrow)#

Aggregate position gradients from all active Jastrow components.

Parameters:

grad_jastrow (Jastrow_data)

apply_block_update(block)#

Apply a single variational-parameter block update to this Jastrow object.

This method is the Jastrow-specific counterpart of Wavefunction_data.apply_block_updates(). It receives a generic VariationalParameterBlock whose values have already been updated (typically by block.apply_update inside the SR/MCMC driver), and interprets that block according to Jastrow semantics.

Responsibilities of this method are:

  • Map the block name (e.g. "j1_param", "j2_param", "j3_matrix") to the corresponding internal Jastrow field(s).

  • Enforce Jastrow-specific structural constraints when copying the block values into the internal arrays. In particular, for the three-body Jastrow term (J3) this includes:

    • Handling the case where only the last column is variational and the rest of the matrix is constrained.

    • Handling the fully square J3 matrix case.

    • Enforcing the required symmetry of the square J3 block.

By keeping all J1/J2/J3 interpretation and constraints in this method (and in the surrounding Jastrow_data class), the optimizer and VariationalParameterBlock remain completely structure-agnostic. To introduce a new Jastrow parameter, extend the block construction in Wavefunction_data.get_variational_blocks and add the corresponding handling here, without touching the SR/MCMC driver.

Parameters:

block (VariationalParameterBlock)

Return type:

Jastrow_data
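The sketch below illustrates the J3 cases listed above for a J matrix with shape (orb_num, orb_num + 1). In one case, only the last column is taken from the updated block. In the other, the whole matrix is taken and its square part is symmetrized. The only_last_column switch, the averaging with the transpose, and the name apply_j3_block are assumptions, not documented jqmc behavior.

```python
import numpy as np


def apply_j3_block(j_matrix, block_values, only_last_column):
    """Hypothetical J3 block update for j_matrix of shape (orb_num, orb_num + 1)."""
    j_new = np.array(j_matrix, dtype=float, copy=True)
    values = np.asarray(block_values, dtype=float)
    if only_last_column:
        # Only the J1-like last column is variational; the square block is kept.
        j_new[:, -1] = values.reshape(-1)
    else:
        # Full update: copy everything, then enforce symmetry of the square block.
        j_new[...] = values
        square = j_new[:, :-1]
        j_new[:, :-1] = 0.5 * (square + square.T)
    return j_new


orb_num = 3
j = np.zeros((orb_num, orb_num + 1))
j_last = apply_j3_block(j, [0.1, 0.2, 0.3], only_last_column=True)
j_full = apply_j3_block(j, np.arange(12.0).reshape(3, 4), only_last_column=False)
print(np.allclose(j_full[:, :-1], j_full[:, :-1].T))  # True
```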

collect_param_grads(grad_jastrow)#

Collect parameter gradients into a flat dict keyed by block name.

Parameters:

grad_jastrow (Jastrow_data)

Return type:

dict[str, object]

jastrow_nn_data: Jastrow_NN_data | None = None#

Optional NN-based three-body Jastrow component.

jastrow_one_body_data: Jastrow_one_body_data | None = None#

Optional one-body Jastrow component.

jastrow_three_body_data: Jastrow_three_body_data | None = None#

Optional analytic three-body Jastrow component.

jastrow_two_body_data: Jastrow_two_body_data | None = None#

Optional two-body Jastrow component.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

class jqmc.jastrow_factor.Jastrow_one_body_data(jastrow_1b_param=1.0, structure_data=<factory>, core_electrons=<factory>)#

Bases: object

One-body Jastrow parameters and structure metadata.

The one-body term models electron–nucleus correlations using an exponential form. The numerical value is returned without the exp wrapper; callers attach exp(J) to the wavefunction.

Parameters:
  • jastrow_1b_param (float) – Parameter controlling the one-body decay.

  • structure_data (Structure_data) – Nuclear positions and charges.

  • core_electrons (tuple[float]) – Removed core electrons per nucleus (for ECPs).

core_electrons: list[float] | tuple[float]#

Effective core-electron counts aligned with structure_data.

classmethod init_jastrow_one_body_data(jastrow_1b_param, structure_data, core_electrons)#

Construct a Jastrow_one_body_data instance from the given parameters.

jastrow_1b_param: float = 1.0#

One-body Jastrow exponent parameter.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

structure_data: Structure_data#

Nuclear structure data providing positions and atomic numbers.

class jqmc.jastrow_factor.Jastrow_three_body_data(orb_data=<factory>, j_matrix=<factory>)#

Bases: object

Three-body Jastrow parameters and orbital references.

The three-body term uses a matrix layout consisting of a square J3 block plus a last-column J1-like vector. Values are returned without exponentiation; callers attach exp(J) to the wavefunction.

Parameters:
  • orb_data (AOs_sphe_data | AOs_cart_data | MOs_data) – Basis/orbital data used for both spins.

  • j_matrix (npt.NDArray | jax.Array) – J matrix with shape (orb_num, orb_num + 1).

property compute_orb_api: Callable[[...], ndarray[tuple[Any, ...], dtype[float64]]]#

Function for computing AOs or MOs.

The API method to compute AOs or MOs corresponding to the instance stored in self.orb_data.

Returns:

The API method to compute AOs or MOs.

Return type:

Callable

Raises:

NotImplementedError – If orb_data is neither AOs_data nor MOs_data.

classmethod init_jastrow_three_body_data(orb_data, random_init=False, random_scale=0.01, seed=None)#

Construct a Jastrow_three_body_data instance whose J-matrix is sized from orb_data.

Parameters:
  • orb_data (AOs_sphe_data | AOs_cart_data | MOs_data) – Orbital container (AOs or MOs) used to size the J-matrix.

  • random_init (bool) – If True, initialize with small random values instead of zeros (for tests).

  • random_scale (float) – Upper bound of uniform sampler when random_init is True (default 0.01).

  • seed (int | None) – Optional seed for deterministic initialization when random_init is True.

j_matrix: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array#

J3/J1 matrix; square block plus final column.

orb_data: AOs_sphe_data | AOs_cart_data | MOs_data#

Orbital basis (AOs or MOs) shared across spins.

property orb_num: int#

Number of orbitals (AOs or MOs) in orb_data.

Returns:

The number of atomic or molecular orbitals, depending on the instance stored in orb_data.

Return type:

int

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

class jqmc.jastrow_factor.Jastrow_two_body_data(jastrow_2b_param=1.0)#

Bases: object

Two-body Jastrow parameter container.

The two-body term uses the Pade functional form J2(r) = r / (2 * (1 + a r)), with a = jastrow_2b_param. Values are returned without exponentiation; callers use exp(J) when constructing the wavefunction.

Parameters:

jastrow_2b_param (float) – Parameter for the two-body Jastrow part.

classmethod init_jastrow_two_body_data(jastrow_2b_param=1.0)#

Construct a Jastrow_two_body_data instance.

jastrow_2b_param: float = 1.0#

Pade parameter a for J2.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the attributes.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

class jqmc.jastrow_factor.NNJastrow(hidden_dim=64, num_layers=3, num_rbf=32, cutoff=5.0, species_lookup=None, num_species=None, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

PauliNet-inspired NN that outputs a three-body Jastrow correction.

The network implements the iteration rules described in the PauliNet manuscript (Eq. 1–2). Electron embeddings \(\mathbf{x}_i^{(n)}\) are iteratively refined by three message channels:

  • (+): same-spin electrons. The messages are kept exchange-equivariant, so the antisymmetry of the full wavefunction is preserved indirectly.

  • (-): opposite-spin electrons, capturing pairing terms.

  • (n): nuclei, represented by fixed species embeddings.

After num_layers iterations the final electron embeddings are summed and fed through \(\eta_\theta\) to produce a symmetric correction that is added on top of the analytic three-body Jastrow.

Parameters:
  • hidden_dim (int)

  • num_layers (int)

  • num_rbf (int)

  • cutoff (float)

  • species_lookup (ndarray[tuple[Any, ...], dtype[int32]] | Array | tuple[int, ...] | None)

  • num_species (int | None)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

class PauliNetBlock(hidden_dim, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Single PauliNet message-passing iteration following Eq. (1).

Each block mixes three message channels per electron: same-spin (+), opposite-spin (-), and nucleus-electron (n). The sender network is shared across channels to match the PauliNet weight-tying scheme, while separate weighting/receiver networks parameterize the contribution of every channel.

Parameters:
  • hidden_dim (int)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

hidden_dim: int#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()#

Instantiate the shared sender/receiver networks for this block.

class PhysNetRadialLayer(num_rbf, cutoff, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Cuspless PhysNet-inspired radial features \(e_k(r)\).

The basis follows Eq. (3) in the PauliNet supplement with a PhysNet-style envelope that forces both the value and the derivative of each Gaussian to vanish at the cutoff and the origin. These features are reused across all message channels, ensuring consistent geometric encoding.

Parameters:
  • num_rbf (int)

  • cutoff (float)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

cutoff: float#
name: str | None = None#
num_rbf: int#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
class TwoLayerMLP(width, out_dim, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Utility MLP used for \(w_\theta\), \(h_\theta\), \(g_\theta\), and \(\eta_\theta\).

Parameters:
  • width (int)

  • out_dim (int)

  • parent (Module | Scope | _Sentinel | None)

  • name (str | None)

name: str | None = None#
out_dim: int#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
width: int#
cutoff: float = 5.0#
hidden_dim: int = 64#
name: str | None = None#
num_layers: int = 3#
num_rbf: int = 32#
num_species: int | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()#

Instantiate PauliNet components and validate required metadata.

Raises:

ValueError – If species_lookup or num_species were not provided via the host dataclass before module initialization.

species_lookup: ndarray[tuple[Any, ...], dtype[int32]] | Array | tuple[int, ...] | None = None#
jqmc.jastrow_factor.compute_Jastrow_one_body(jastrow_one_body_data, r_up_carts, r_dn_carts)#

Evaluate the one-body Jastrow \(J_1\) (without exp) for the given coordinates.

This routine returns the scalar value J; callers attach exp(J) to the wavefunction.

Parameters:
  • jastrow_one_body_data (Jastrow_one_body_data) – One-body Jastrow parameters and structure data.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

One-body Jastrow value (before exponentiation).

Return type:

float

jqmc.jastrow_factor.compute_Jastrow_part(jastrow_data, r_up_carts, r_dn_carts)#

Evaluate the total Jastrow \(J = J_1 + J_2 + J_3\) (without exponentiation).

The returned scalar J excludes the exp factor; callers apply exp(J) to the wavefunction. The three-body contribution includes both the analytic term and the optional NN term.

Parameters:
  • jastrow_data (Jastrow_data) – Collection of active Jastrow components (J1/J2/J3/NN).

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Total Jastrow value before exponentiation.

Return type:

float

jqmc.jastrow_factor.compute_Jastrow_three_body(jastrow_three_body_data, r_up_carts, r_dn_carts)#

Evaluate the analytic three-body Jastrow \(J_3\) without exponentiation.

The square J3 block couples electron pairs, and the last column acts as a J1-like vector. The returned value is J; attach exp(J) externally.

Parameters:
  • jastrow_three_body_data (Jastrow_three_body_data) – Three-body Jastrow parameters and orbitals.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Three-body Jastrow value (before exponentiation).

Return type:

float
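The sketch below evaluates \(J_3\) from orbital values to illustrate the layout described above. It makes assumptions that this page does not confirm. Electrons of both spins are treated together. The square block is summed over distinct pairs \(i < j\) as \(\chi(\mathbf{r}_i)^T J \chi(\mathbf{r}_j)\). The last column contributes \(\sum_i \sum_a J_{a,\mathrm{last}} \chi_a(\mathbf{r}_i)\). The orbital values are random stand-ins, and j3_value is a hypothetical name.

```python
import numpy as np


def j3_value(chi, j_matrix):
    """Hypothetical J3 from orbital values chi with shape (n_orb, N_e).

    j_matrix has shape (n_orb, n_orb + 1): square block plus last (J1-like) column.
    """
    j_sq = j_matrix[:, :-1]
    j_last = j_matrix[:, -1]
    one_body = np.sum(j_last @ chi)        # sum_i sum_a J_{a,last} chi_a(r_i)
    pair = chi.T @ j_sq @ chi              # (N_e, N_e): chi(r_i)^T J chi(r_j)
    two_body = np.sum(np.triu(pair, k=1))  # distinct pairs i < j
    return one_body + two_body


rng = np.random.default_rng(3)
n_orb, n_e = 4, 5
chi = rng.normal(size=(n_orb, n_e))
a = rng.normal(size=(n_orb, n_orb))
j_mat = np.hstack([0.5 * (a + a.T), rng.normal(size=(n_orb, 1))])  # symmetric square block
value = j3_value(chi, j_mat)
```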

jqmc.jastrow_factor.compute_Jastrow_two_body(jastrow_two_body_data, r_up_carts, r_dn_carts)#

Evaluate the two-body Jastrow \(J_2\) (Pade form) without exponentiation.

This returns J; callers attach exp(J) to the wavefunction.

Parameters:
  • jastrow_two_body_data (Jastrow_two_body_data) – Two-body Jastrow parameter container.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Two-body Jastrow value (before exponentiation).

Return type:

float
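The NumPy sketch below evaluates the Pade form used for the two-body term, \(J_2(r) = r / (2(1 + a r))\) with \(a\) = jastrow_2b_param, summed over all distinct electron pairs. Applying the same form to same-spin and opposite-spin pairs is an assumption of this sketch, and j2_value is a hypothetical name.

```python
import numpy as np


def j2_value(r_up_carts, r_dn_carts, a):
    """Sum of u(r_ij) = r_ij / (2 (1 + a r_ij)) over all distinct electron pairs."""
    r = np.vstack([r_up_carts, r_dn_carts])   # (N_up + N_dn, 3)
    i, j = np.triu_indices(len(r), k=1)       # pairs with i < j
    d = np.linalg.norm(r[i] - r[j], axis=-1)
    return np.sum(d / (2.0 * (1.0 + a * d)))


r_up = np.array([[0.0, 0.0, 0.0]])
r_dn = np.array([[0.0, 0.0, 1.0]])
print(j2_value(r_up, r_dn, a=1.0))  # 0.25
```

At short range \(u(r) \approx r/2\), so \(u'(0) = 1/2\), the electron–electron cusp value for opposite spins, independent of \(a\).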

jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_one_body(jastrow_one_body_data, r_up_carts, r_dn_carts)#

Analytic gradients and per-electron Laplacians for the one-body Jastrow.

Parameters:
  • jastrow_one_body_data (Jastrow_one_body_data) – One-body Jastrow parameters and structure data.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array]

jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_part(jastrow_data, r_up_carts, r_dn_carts)#

Per-electron gradients and Laplacians of the full Jastrow \(J\).

Analytic paths are used for J1/J2/J3 when available; the NN three-body term (if present) is handled via autodiff. Values are returned per electron (not summed) to match downstream kinetic-energy estimators.

Parameters:
  • jastrow_data (Jastrow_data) – Active Jastrow components.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3) and Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array]

jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_three_body(jastrow_three_body_data, r_up_carts, r_dn_carts)#

Analytic gradients and Laplacians for the three-body Jastrow.

This routine uses analytic AO/MO gradients and Laplacians. Per-electron derivatives are returned (not summed), matching kinetic-energy estimator expectations.

Parameters:
  • jastrow_three_body_data (Jastrow_three_body_data) – Three-body Jastrow parameters and orbitals.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array]

jqmc.jastrow_factor.compute_grads_and_laplacian_Jastrow_two_body(jastrow_two_body_data, r_up_carts, r_dn_carts)#

Analytic gradients and Laplacians for the Padé two-body Jastrow.

Uses the functional form \(J_2(r) = \frac{r}{2(1 + a r)}\) with \(a\) = jastrow_2b_param. Returns per-electron quantities (not summed).

Parameters:
  • jastrow_two_body_data (Jastrow_two_body_data) – Two-body Jastrow parameter container.

  • r_up_carts (Array) – Spin-up electron coordinates with shape (N_up, 3).

  • r_dn_carts (Array) – Spin-down electron coordinates with shape (N_dn, 3).

Returns:

Gradients for up/down electrons with shapes (N_up, 3) and (N_dn, 3), Laplacians for up/down electrons with shapes (N_up,) and (N_dn,).

Return type:

tuple[jax.Array, jax.Array, jax.Array, jax.Array]
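The closed-form derivatives can be sketched in NumPy as follows. This assumes, for illustration only, that \(J_2\) is the sum of \(u(r_{ij}) = r_{ij}/(2(1 + a r_{ij}))\) over all electron pairs with one Padé function for every spin combination (jqmc may treat parallel- and antiparallel-spin pairs differently). The helpers pade_u_derivs, j2_value and j2_grads_and_laplacians are hypothetical, not jqmc functions; they use \(u' = 1/(2(1+ar)^2)\), \(u'' = -a/(1+ar)^3\) and the 3D radial Laplacian \(u'' + 2u'/r\).

```python
import numpy as np


def pade_u_derivs(r, a):
    """Pade pair term u(r) = r / (2 (1 + a r)) and its radial derivatives."""
    d = 1.0 + a * r
    u = r / (2.0 * d)
    du = 1.0 / (2.0 * d**2)   # u'(r)
    d2u = -a / d**3           # u''(r)
    return u, du, d2u


def j2_value(r_carts, a):
    """J2 = sum over pairs i < j of u(r_ij) (hypothetical: same u for all spins)."""
    i, j = np.triu_indices(len(r_carts), k=1)
    r_ij = np.linalg.norm(r_carts[i] - r_carts[j], axis=-1)
    return pade_u_derivs(r_ij, a)[0].sum()


def j2_grads_and_laplacians(r_up_carts, r_dn_carts, a):
    """Per-electron gradients (N, 3) and Laplacians (N,) of j2_value."""
    r = np.vstack([r_up_carts, r_dn_carts])       # (N_e, 3)
    diff = r[:, None, :] - r[None, :, :]           # r_i - r_j, (N_e, N_e, 3)
    off = ~np.eye(len(r), dtype=bool)              # exclude i == j
    dist = np.where(off, np.linalg.norm(diff, axis=-1), 1.0)  # dummy 1 on diagonal
    _, du, d2u = pade_u_derivs(dist, a)
    # grad_i u(r_ij) = u'(r_ij) (r_i - r_j) / r_ij, summed over j != i
    grad = np.where(off[..., None], (du / dist)[..., None] * diff, 0.0).sum(axis=1)
    # 3D radial Laplacian: lap_i u(r_ij) = u''(r_ij) + 2 u'(r_ij) / r_ij
    lap = np.where(off, d2u + 2.0 * du / dist, 0.0).sum(axis=1)
    n_up = len(r_up_carts)
    return grad[:n_up], grad[n_up:], lap[:n_up], lap[n_up:]
```

The return order (up gradients, down gradients, up Laplacians, down Laplacians) mirrors the return type listed above; a finite-difference check against j2_value is a quick way to validate such formulas.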

jqmc.jastrow_factor.compute_ratio_Jastrow_part(jastrow_data, old_r_up_carts, old_r_dn_carts, new_r_up_carts_arr, new_r_dn_carts_arr)#

Compute \(\exp(J(\mathbf{r}'))/\exp(J(\mathbf{r}))\) for batched moves.

The exponential is applied, so each returned value is a ratio rather than a log-difference. Inputs are jax.Array, and one ratio is returned per proposed grid configuration.

Parameters:
  • jastrow_data (Jastrow_data) – Active Jastrow components.

  • old_r_up_carts (Array) – Reference spin-up coordinates with shape (N_up, 3).

  • old_r_dn_carts (Array) – Reference spin-down coordinates with shape (N_dn, 3).

  • new_r_up_carts_arr (Array) – Proposed spin-up coordinates with shape (N_grid, N_up, 3).

  • new_r_dn_carts_arr (Array) – Proposed spin-down coordinates with shape (N_grid, N_dn, 3).

Returns:

Jastrow ratios per grid with shape (N_grid,) (includes exp).

Return type:

jax.Array
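The batched shape contract can be sketched in NumPy. Here toy_jastrow is a hypothetical stand-in for \(J\) (a Padé two-body term over all electron pairs), not jqmc's Jastrow; only the input/output shapes and the per-grid ratio follow the description above. Exponentiating the difference \(J(\mathbf{r}') - J(\mathbf{r})\) is a design choice of this sketch that avoids overflowing two large exponentials.

```python
import numpy as np


def toy_jastrow(r_up, r_dn, a=1.0):
    """Hypothetical stand-in for J: a Pade two-body term over all electron pairs."""
    r = np.vstack([r_up, r_dn])
    i, j = np.triu_indices(len(r), k=1)
    r_ij = np.linalg.norm(r[i] - r[j], axis=-1)
    return np.sum(r_ij / (2.0 * (1.0 + a * r_ij)))


def jastrow_ratios(old_r_up, old_r_dn, new_r_up_arr, new_r_dn_arr, a=1.0):
    """exp(J(r')) / exp(J(r)) for each of the N_grid proposed configurations.

    old_r_up: (N_up, 3), old_r_dn: (N_dn, 3)
    new_r_up_arr: (N_grid, N_up, 3), new_r_dn_arr: (N_grid, N_dn, 3)
    """
    j_old = toy_jastrow(old_r_up, old_r_dn, a)
    j_new = np.array(
        [toy_jastrow(up, dn, a) for up, dn in zip(new_r_up_arr, new_r_dn_arr)]
    )
    # exp of the difference avoids overflow that exp(J') / exp(J) could hit
    return np.exp(j_new - j_old)                   # shape (N_grid,)
```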

jqmc.jqmc_cli module#

Command-line module.

jqmc.jqmc_gfmc module#

QMC module.

class jqmc.jqmc_gfmc.GFMC_n(hamiltonian_data=None, num_walkers=40, num_mcmc_per_measurement=16, num_gfmc_collect_steps=5, mcmc_seed=34467, E_scf=0.0, alat=0.1, random_discretized_mesh=True, non_local_move='tmove', comput_position_deriv=False)#

Bases: object

GFMC class for running GFMC with multiple walkers.

Parameters:
  • hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data

  • num_walkers (int) – the number of walkers

  • mcmc_seed (int) – seed for the MCMC chain.

  • E_scf (float) – Self-consistent energy (Hartree)

  • alat (float) – discretized grid length (bohr)

  • random_discretized_mesh (bool) – Flag for the random discretization mesh in the kinetic part and the non-local part of ECPs. Valid both for all-electron and ECP calculations.

  • non_local_move (str) – treatment of the sign-flip term: tmove (Casula’s T-move) or dtmove (Determinant Locality Approximation with Casula’s T-move). Valid only for ECP calculations; do not specify this value for all-electron calculations.

  • comput_position_deriv (bool) – if True, compute the derivatives of E wrt. atomic positions.

  • num_mcmc_per_measurement (int)

  • num_gfmc_collect_steps (int) – the number of steps to collect the GFMC data

property alat#

Return alat.

property bare_w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored weight array with shape (mcmc_counter, 1).

property comput_position_deriv: bool#

Return the flag for computing the derivatives of E wrt. atomic positions.

property de_L_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dR array with shape (mcmc_counter, 1).

property de_L_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).

property de_L_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).

property dln_Psi_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dR array with shape (mcmc_counter, 1, num_atoms, 3).

property dln_Psi_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).

property dln_Psi_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).

property domega_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dOmega/dr_dn array with shape (mcmc_counter, 1, num_electrons_dn, 3).

property domega_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dOmega/dr_up array with shape (mcmc_counter, 1, num_electrons_up, 3).

property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L array with shape (mcmc_counter, 1).

property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L^2 array with shape (mcmc_counter, 1).

get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#

Return the mean and std of the computed local energy.

Parameters:
  • num_mcmc_warmup_steps (int) – the number of warmup steps.

  • num_mcmc_bin_blocks (int) – the number of binning blocks

Returns:

The mean and std of the total energy and of the variance, estimated by the jackknife method with the given arguments: (E_mean, E_std, Var_mean, Var_std).

Return type:

tuple[float, float, float, float]

get_aF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#

Return the mean and std of the computed atomic forces.

Parameters:
  • num_mcmc_warmup_steps (int) – the number of warmup steps.

  • num_mcmc_bin_blocks (int) – the number of binning blocks

Returns:

The mean and std of the computed atomic forces, estimated by the jackknife method with the given arguments. Both arrays have shape (N, 3), where N is the number of atoms.

Return type:

tuple[npt.NDArray, npt.NDArray]

property hamiltonian_data#

Return hamiltonian_data.

property mcmc_counter: int#

Return current MCMC counter.

property num_gfmc_collect_steps#

Return num_gfmc_collect_steps.

property num_walkers#

The number of walkers.

property omega_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored Omega (for down electrons) array with shape (mcmc_counter, 1, num_atoms, num_electrons_dn).

property omega_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored Omega (for up electrons) array with shape (mcmc_counter, 1, num_atoms, num_electrons_up).

run(num_mcmc_steps=50, max_time=86400)#

Run LRDMC with multiple walkers.

Parameters:
  • num_mcmc_steps (int) – number of branching steps (reconfigurations of walkers).

  • max_time (int) – maximum wall-clock time in seconds.

Return type:

None

property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored weight array with shape (mcmc_counter, 1).

class jqmc.jqmc_gfmc.GFMC_t(hamiltonian_data=None, num_walkers=40, num_gfmc_collect_steps=5, mcmc_seed=34467, tau=0.1, alat=0.1, random_discretized_mesh=True, non_local_move='tmove')#

Bases: object

GFMC class for running GFMC with multiple walkers.

Parameters:
  • hamiltonian_data (Hamiltonian_data) – an instance of Hamiltonian_data

  • num_walkers (int) – the number of walkers

  • num_gfmc_collect_steps (int) – the number of steps to collect the GFMC data

  • mcmc_seed (int) – seed for the MCMC chain.

  • tau (float) – projection time (Hartree^-1)

  • alat (float) – discretized grid length (bohr)

  • random_discretized_mesh (bool) – Flag for the random discretization mesh in the kinetic part and in the non-local part of ECPs. Valid both for all-electron and ECP calculations.

  • non_local_move (str) – treatment of the sign-flip term: tmove (Casula’s T-move) or dtmove (Determinant Locality Approximation with Casula’s T-move). Valid only for ECP calculations; do not specify this value for all-electron calculations.

property alat#

Return alat.

property bare_w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored weight array with shape (mcmc_counter, 1).

property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L array with shape (mcmc_counter, 1).

property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L^2 array with shape (mcmc_counter, 1).

get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#

Return the mean and std of the computed local energy.

Parameters:
  • num_mcmc_warmup_steps (int) – the number of warmup steps.

  • num_mcmc_bin_blocks (int) – the number of binning blocks

Returns:

The mean and std of the total energy and of the variance, estimated by the jackknife method with the given arguments: (E_mean, E_std, Var_mean, Var_std).

Return type:

tuple[float, float, float, float]

property hamiltonian_data#

Return hamiltonian_data.

property mcmc_counter: int#

Return current MCMC counter.

property num_gfmc_collect_steps#

Return num_gfmc_collect_steps.

property num_walkers#

The number of walkers.

run(num_mcmc_steps=50, max_time=86400)#

Run LRDMC with multiple walkers.

Parameters:
  • num_mcmc_steps (int) – number of branching steps (reconfigurations of walkers).

  • max_time (int) – maximum wall-clock time in seconds.

Return type:

None

property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored weight array with shape (mcmc_counter, 1).

jqmc.jqmc_mcmc module#

QMC module.

class jqmc.jqmc_mcmc.MCMC(hamiltonian_data=None, mcmc_seed=34467, num_walkers=40, num_mcmc_per_measurement=16, Dt=2.0, epsilon_AS=0.1, comput_param_deriv=False, comput_position_deriv=False, random_discretized_mesh=True)#

Bases: object

Production VMC/MCMC driver with multiple walkers.

This class drives Metropolis–Hastings sampling for many independent walkers in parallel (vectorized with jax.vmap) and stores all observables needed by downstream analysis and optimization. All public methods are part of the supported API; private helpers are internal and subject to change.

Parameters:
  • hamiltonian_data (Hamiltonian_data)

  • mcmc_seed (int)

  • num_walkers (int)

  • num_mcmc_per_measurement (int)

  • Dt (float)

  • epsilon_AS (float)

  • comput_param_deriv (bool)

  • comput_position_deriv (bool)

  • random_discretized_mesh (bool)

property comput_param_deriv: bool#

Return the flag for computing the derivatives of E wrt. variational parameters.

property comput_position_deriv: bool#

Return the flag for computing the derivatives of E wrt. atomic positions.

property de_L_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dR array with shape (mcmc_counter, num_walkers).

property de_L_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).

property de_L_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored de_L/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).

property dln_Psi_dR: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dR array with shape (mcmc_counter, num_walkers, num_atoms, 3).

property dln_Psi_dc: dict[str, ndarray[tuple[Any, ...], dtype[_ScalarT]]]#

Return stored parameter gradients keyed by block name.

property dln_Psi_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).

property dln_Psi_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dln_Psi/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).

property domega_dr_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dOmega/dr_dn array with shape (mcmc_counter, num_walkers, num_electrons_dn, 3).

property domega_dr_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored dOmega/dr_up array with shape (mcmc_counter, num_walkers, num_electrons_up, 3).

property e_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L array with shape (mcmc_counter, num_walkers).

property e_L2: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored e_L^2 array with shape (mcmc_counter, num_walkers).

get_E(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#

Estimate total energy and variance with jackknife error bars.

Parameters:
  • num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.

  • num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.

Returns:

(E_mean, E_std, Var_mean, Var_std) aggregated across MPI ranks.

Return type:

tuple[float, float, float, float]

Raises:

ValueError – If there are insufficient post-warmup samples to form the requested blocks.

Notes

Warns when warmup or block counts fall below MCMC_MIN_WARMUP_STEPS / MCMC_MIN_BIN_BLOCKS. All reductions are MPI-aware.
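The blocking and jackknife steps can be sketched for a single process as follows. This is a generic weighted jackknife under stated assumptions (blocks of consecutive MCMC steps, variance taken as \(\langle E^2\rangle - \langle E\rangle^2\), plain leave-one-block-out estimates), not jqmc's implementation; jackknife_energy is hypothetical and performs no MPI reduction.

```python
import numpy as np


def jackknife_energy(e_L, w_L, num_warmup=50, num_blocks=10):
    """Weighted jackknife estimates (E_mean, E_std, Var_mean, Var_std).

    e_L, w_L: arrays of shape (mcmc_counter, num_walkers). Single process, no MPI.
    """
    e = np.asarray(e_L, dtype=float)[num_warmup:]   # (M, num_walkers)
    w = np.asarray(w_L, dtype=float)[num_warmup:]
    if len(e) < num_blocks:
        raise ValueError("insufficient post-warmup samples for the requested blocks")
    blocks = np.array_split(np.arange(len(e)), num_blocks)  # consecutive MCMC steps
    sw = np.array([w[b].sum() for b in blocks])
    swe = np.array([(w[b] * e[b]).sum() for b in blocks])
    swe2 = np.array([(w[b] * e[b] ** 2).sum() for b in blocks])
    # Leave-one-block-out estimates of <E> and <E^2> - <E>^2
    w_rest = sw.sum() - sw
    E_j = (swe.sum() - swe) / w_rest
    Var_j = (swe2.sum() - swe2) / w_rest - E_j**2
    scale = (num_blocks - 1) / num_blocks

    def mean_and_std(x):
        # Jackknife error: sqrt((n - 1) / n * sum (x_j - mean)^2)
        return x.mean(), np.sqrt(scale * np.sum((x - x.mean()) ** 2))

    E_mean, E_std = mean_and_std(E_j)
    Var_mean, Var_std = mean_and_std(Var_j)
    return E_mean, E_std, Var_mean, Var_std
```

With arrays shaped like the stored e_L and w_L, a call such as jackknife_energy(e_L, w_L, num_warmup=50, num_blocks=10) mirrors the default arguments of get_E().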

get_aF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10)#

Compute Hellmann–Feynman + Pulay forces with jackknife statistics.

Parameters:
  • num_mcmc_warmup_steps (int, optional) – Samples to drop for warmup. Defaults to 50.

  • num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.

Returns:

(force_mean, force_std) shaped (num_atoms, 3) in Hartree/bohr.

Return type:

tuple[npt.NDArray, npt.NDArray]

Notes

Uses stored per-walker weights, energies, wavefunction gradients, and SWCT terms accumulated during run(); reductions are MPI-aware.

get_dln_WF(blocks, num_mcmc_warmup_steps=50, chosen_param_index=None)#

Assemble per-sample derivatives of ln Psi w.r.t. variational parameters.

Parameters:
  • blocks (list[VariationalParameterBlock]) – Ordered variational blocks used for concatenation.

  • num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.

  • chosen_param_index (list | None, optional) – Optional subset of flattened indices; None keeps all. Defaults to None.

Returns:

O_matrix with shape (M, num_walkers, K), where M is the number of post-warmup samples and K follows the provided blocks (or the chosen subset).

Return type:

npt.NDArray

Notes

Validates the concatenated gradient size against block metadata and uses gradients stored during run().

get_gF(num_mcmc_warmup_steps=50, num_mcmc_bin_blocks=10, chosen_param_index=None, blocks=None)#

Evaluate generalized forces (dE/dc_k) with jackknife error bars.

Parameters:
  • num_mcmc_warmup_steps (int, optional) – Samples to discard as warmup. Defaults to 50.

  • num_mcmc_bin_blocks (int, optional) – Number of jackknife blocks. Defaults to 10.

  • chosen_param_index (list | None, optional) – Optional subset of flattened indices. Defaults to None.

  • blocks (list | None, optional) – Variational blocks for parameter ordering; defaults to current wavefunction blocks.

Returns:

(generalized_force_mean, generalized_force_std) as 1D vectors whose length equals the number of selected parameters after any filtering.

Return type:

tuple[npt.NDArray, npt.NDArray]

Notes

Reuses get_dln_WF() after warmup and applies jackknife statistics across MPI ranks.
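For reference, the textbook VMC estimator of the energy derivative for a real wavefunction is \(\partial E/\partial c_k = 2(\langle e_L O_k\rangle - \langle e_L\rangle\langle O_k\rangle)\) with \(O_k = \partial \ln\Psi/\partial c_k\). The single-process sketch below takes O in the (M, num_walkers, K) layout returned by get_dln_WF(); generalized_force_estimate is hypothetical, it omits the jackknife error bars, and jqmc's sign convention may differ (a "generalized force" is often defined as \(-\partial E/\partial c_k\)).

```python
import numpy as np


def generalized_force_estimate(e_L, O, w_L=None):
    """Standard VMC estimator of dE/dc_k for a real wavefunction.

    e_L: (M, W) local energies; O: (M, W, K) with O_k = d ln Psi / d c_k.
    Returns a length-K vector 2 * (<e_L O_k> - <e_L><O_k>).
    """
    e = np.asarray(e_L, dtype=float).reshape(-1)
    o = np.asarray(O, dtype=float).reshape(e.size, -1)   # (M * W, K)
    w = np.ones_like(e) if w_L is None else np.asarray(w_L, dtype=float).reshape(-1)
    w = w / w.sum()                                      # normalized sample weights
    e_mean = w @ e
    o_mean = w @ o
    eo_mean = (w * e) @ o
    return 2.0 * (eo_mean - e_mean * o_mean)
```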

property hamiltonian_data#

Access the mutable Hamiltonian_data backing this sampler.

property mcmc_counter: int#

Number of Metropolis steps accumulated (rows in stored observables).

property num_walkers#

The number of walkers.

property omega_dn: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored Omega (for down electrons) array with shape (mcmc_counter, num_walkers, num_atoms, num_electrons_dn).

property omega_up: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored Omega (for up electrons) array with shape (mcmc_counter, num_walkers, num_atoms, num_electrons_up).

run(num_mcmc_steps=0, max_time=86400)#

Execute Metropolis–Hastings sampling for all walkers.

Parameters:
  • num_mcmc_steps (int, optional) – Metropolis updates per walker; values <= 0 are no-ops. Defaults to 0.

  • max_time (int, optional) – Wall-clock budget in seconds. Defaults to 86400.

Return type:

None

Notes

  • Creates external_control_mcmc.toml to allow external stop requests.

  • Accumulates energies, weights, forces, and wavefunction gradients into public buffers (w_L, e_L, dln_Psi_* etc.).

  • Logs timing statistics and acceptance ratios at the end of the run.
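As a point of orientation, the walker-vectorized Metropolis idea can be sketched with a generic single-electron random-walk proposal that samples \(|\Psi|^2\). This is not jqmc's proposal: the Dt and epsilon_AS parameters are not modeled, the Python loop replaces jax.vmap, and metropolis_step and its log_psi argument are hypothetical.

```python
import numpy as np


def metropolis_step(log_psi, r, step, rng):
    """One random-walk Metropolis move per walker, sampling |Psi|^2.

    r: (num_walkers, N_e, 3); log_psi maps one (N_e, 3) configuration to ln|Psi|.
    One randomly chosen electron per walker is displaced by a Gaussian step.
    """
    num_walkers, n_e, _ = r.shape
    new_r = r.copy()
    moved = rng.integers(n_e, size=num_walkers)
    new_r[np.arange(num_walkers), moved] += step * rng.normal(size=(num_walkers, 3))
    log_old = np.array([log_psi(x) for x in r])
    log_new = np.array([log_psi(x) for x in new_r])
    # Symmetric proposal, so accept with probability min(1, |Psi'|^2 / |Psi|^2).
    accept = np.log(rng.uniform(size=num_walkers)) < 2.0 * (log_new - log_old)
    r_next = np.where(accept[:, None, None], new_r, r)
    return r_next, accept.mean()
```

The returned acceptance fraction corresponds to the kind of acceptance ratio that run() reports at the end.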

run_optimize(num_mcmc_steps=100, num_opt_steps=1, wf_dump_freq=10, max_time=86400, num_mcmc_warmup_steps=0, num_mcmc_bin_blocks=100, opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=False, num_param_opt=0, optimizer_kwargs=None)#

Optimize wavefunction parameters using SR or optax.

Parameters:
  • num_mcmc_steps (int, optional) – MCMC samples per walker per iteration. Defaults to 100.

  • num_opt_steps (int, optional) – Number of optimization iterations. Defaults to 1.

  • wf_dump_freq (int, optional) – Dump frequency for hamiltonian_data.chk. Defaults to 10.

  • max_time (int, optional) – Per-iteration MCMC wall-clock budget (sec). Defaults to 86400.

  • num_mcmc_warmup_steps (int, optional) – Warmup samples discarded each iteration. Defaults to 0.

  • num_mcmc_bin_blocks (int, optional) – Jackknife bins for statistics. Defaults to 100.

  • opt_J1_param (bool, optional) – Optimize one-body Jastrow. Defaults to True.

  • opt_J2_param (bool, optional) – Optimize two-body Jastrow. Defaults to True.

  • opt_J3_param (bool, optional) – Optimize three-body Jastrow. Defaults to True.

  • opt_JNN_param (bool, optional) – Optimize NN Jastrow. Defaults to True.

  • opt_lambda_param (bool, optional) – Optimize determinant lambda matrix. Defaults to False.

  • num_param_opt (int, optional) – Limit parameters updated (ranked by |f|/|std f|); 0 means all. Defaults to 0.

  • optimizer_kwargs (dict | None, optional) – Optimizer configuration. method='sr' uses SR keys (delta, epsilon, cg_flag, cg_max_iter, cg_tol); other method names are optax constructors (e.g., "adam") and receive remaining keys.

Notes

  • Persists optax optimizer state across calls when method and hyperparameters match.

  • Writes external_control_opt.toml to allow external stop requests.

  • Updates Hamiltonian_data in-place and increments the optimization counter.
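The SR branch (method='sr') can be pictured with the textbook stochastic-reconfiguration update: solve \(S\,\Delta c = f\), with \(S_{kl} = \langle \delta O_k\, \delta O_l\rangle\) and \(f_k = -\partial E/\partial c_k\), then step \(c \leftarrow c + \delta\,\Delta c\). The NumPy sketch below is hypothetical: sr_step is not a jqmc function, delta is assumed to act as the step size and epsilon as a constant diagonal shift of S, and walker weights, MPI reductions and the conjugate-gradient path suggested by the cg_* keys are omitted.

```python
import numpy as np


def sr_step(params, e_L, O, delta=0.01, epsilon=1e-3):
    """One stochastic-reconfiguration update (unweighted, single process).

    params: (K,) parameters; e_L: (N,) local energies;
    O: (N, K) log-derivatives d ln Psi / d c_k for N samples.
    """
    n = len(e_L)
    dO = O - O.mean(axis=0)
    dE = e_L - e_L.mean()
    # Generalized force f_k = -dE/dc_k = -2 <(e_L - E)(O_k - <O_k>)>
    f = -2.0 * (dE @ dO) / n
    # Overlap (covariance) matrix S_kl = <(O_k - <O_k>)(O_l - <O_l>)>
    S = dO.T @ dO / n
    # Assumed regularization: constant diagonal shift epsilon
    S_reg = S + epsilon * np.eye(len(params))
    # Direct dense solve; an iterative CG solve is the usual alternative
    dc = np.linalg.solve(S_reg, f)
    return params + delta * dc
```

For the optax path, the description above implies a configuration of the form optimizer_kwargs={"method": "adam", "learning_rate": 1e-3}, with the remaining keys forwarded to optax.adam; the exact keys jqmc accepts should be checked against its source.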

property w_L: ndarray[tuple[Any, ...], dtype[_ScalarT]]#

Return the stored weight array with shape (mcmc_counter, num_walkers).

jqmc.jqmc_miscs module#

jQMC miscs.

jqmc.jqmc_tool module#

jQMC tools.

class jqmc.jqmc_tool.ansatz_type(value)#

Bases: str, Enum

Wavefunction ansatz type.

jagp = 'jagp'#
jsd = 'jsd'#

jqmc.jqmc_tool.hamiltonian_convert_wavefunction(hamiltonian_data_org_file=<typer.models.ArgumentInfo object>, convert_to=<typer.models.OptionInfo object>, hamiltonian_data_conv_file=<typer.models.OptionInfo object>)#

Convert wavefunction data in the Hamiltonian data.

Parameters:
  • hamiltonian_data_org_file (str)

  • convert_to (ansatz_type)

  • hamiltonian_data_conv_file (str)

jqmc.jqmc_tool.hamiltonian_show_info(hamiltonian_data=<typer.models.ArgumentInfo object>)#

Show information stored in the Hamiltonian data.

Parameters:

hamiltonian_data (str)

jqmc.jqmc_tool.hamiltonian_to_xyz(hamiltonian_data=<typer.models.ArgumentInfo object>, xyz_file=<typer.models.OptionInfo object>)#

Write the atomic structure stored in the Hamiltonian data to an XYZ file.

Parameters:
  • hamiltonian_data (str)

  • xyz_file (str)

jqmc.jqmc_tool.lrdmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#

LRDMC chk file fix.

Parameters:

restart_chk (str)

jqmc.jqmc_tool.lrdmc_compute_energy(restart_chk=<typer.models.ArgumentInfo object>, num_gfmc_bin_block=<typer.models.OptionInfo object>, num_gfmc_warmup_steps=<typer.models.OptionInfo object>, num_gfmc_collect_steps=<typer.models.OptionInfo object>)#

LRDMC energy calculation.

Parameters:
  • restart_chk (str)

  • num_gfmc_bin_block (int)

  • num_gfmc_warmup_steps (int)

  • num_gfmc_collect_steps (int)

jqmc.jqmc_tool.lrdmc_extrapolate_energy(restart_chks=<typer.models.ArgumentInfo object>, polynomial_order=<typer.models.OptionInfo object>, plot_graph=<typer.models.OptionInfo object>, save_graph=<typer.models.OptionInfo object>, num_gfmc_bin_block=<typer.models.OptionInfo object>, num_gfmc_warmup_steps=<typer.models.OptionInfo object>, num_gfmc_collect_steps=<typer.models.OptionInfo object>)#

LRDMC energy extrapolation from multiple checkpoint files, using a polynomial fit of order polynomial_order.

Parameters:
  • restart_chks (List[str])

  • polynomial_order (int)

  • plot_graph (bool)

  • save_graph (str)

  • num_gfmc_bin_block (int)

  • num_gfmc_warmup_steps (int)

  • num_gfmc_collect_steps (int)
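Extrapolation presumably combines checkpoints run at different alat values; a common practice for LRDMC is to fit \(E(a) = E_0 + k_1 a^2 + k_2 a^4 + \dots\) and report the \(a \to 0\) intercept \(E_0\). The NumPy sketch below illustrates that fit under this assumption; extrapolate_lrdmc and its order_in_a2 argument are hypothetical, and how jqmc's polynomial_order maps onto the fit is not stated here.

```python
import numpy as np


def extrapolate_lrdmc(alats, energies, order_in_a2=1):
    """Fit E(a) = E0 + k1 a^2 + ... + k_n a^(2n) and return E0 (the a -> 0 limit).

    order_in_a2 (n) is a hypothetical parameter; it is not jqmc's polynomial_order.
    """
    x = np.asarray(alats, dtype=float) ** 2
    y = np.asarray(energies, dtype=float)
    coeffs = np.polyfit(x, y, deg=order_in_a2)  # highest power first
    return coeffs[-1]                           # constant term E0
```

In practice the fit would also be weighted by the energy error bars of each checkpoint, which this sketch leaves out.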

jqmc.jqmc_tool.lrdmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#

Generate an input file for LRDMC calculations.

Parameters:
  • flag (bool)

  • filename (str)

  • exclude_comment (bool)

jqmc.jqmc_tool.mcmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#

VMC chk file fix.

Parameters:

restart_chk (str)

jqmc.jqmc_tool.mcmc_compute_energy(restart_chk=<typer.models.ArgumentInfo object>, num_mcmc_bin_blocks=<typer.models.OptionInfo object>, num_mcmc_warmup_steps=<typer.models.OptionInfo object>)#

VMC energy calculation.

Parameters:
  • restart_chk (str)

  • num_mcmc_bin_blocks (int)

  • num_mcmc_warmup_steps (int)

jqmc.jqmc_tool.mcmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#

Generate an input file for VMC calculations.

Parameters:
  • flag (bool)

  • filename (str)

  • exclude_comment (bool)

class jqmc.jqmc_tool.orbital_type(value)#

Bases: str, Enum

Orbital type.

ao = 'ao'#
ao_full = 'ao-full'#
ao_large = 'ao-large'#
ao_medium = 'ao-medium'#
ao_small = 'ao-small'#
mo = 'mo'#
none = 'none'#

jqmc.jqmc_tool.trexio_convert_to(trexio_file=<typer.models.ArgumentInfo object>, hamiltonian_file=<typer.models.OptionInfo object>, j1_parmeter=<typer.models.OptionInfo object>, j2_parmeter=<typer.models.OptionInfo object>, j3_basis_type=<typer.models.OptionInfo object>, j_nn_type=<typer.models.OptionInfo object>, j_nn_params=<typer.models.OptionInfo object>)#

Convert a TREXIO file to hamiltonian_data.

Parameters:
  • trexio_file (str)

  • hamiltonian_file (str)

  • j1_parmeter (float)

  • j2_parmeter (float)

  • j3_basis_type (orbital_type)

  • j_nn_type (str)

  • j_nn_params (List[str])

jqmc.jqmc_tool.trexio_show_detail(filename=<typer.models.ArgumentInfo object>)#

Show detailed information stored in the TREXIO file.

Parameters:

filename (str)

jqmc.jqmc_tool.trexio_show_info(filename=<typer.models.ArgumentInfo object>)#

Show information stored in the TREXIO file.

Parameters:

filename (str)

jqmc.jqmc_tool.vmc_analyze_output(filenames=<typer.models.ArgumentInfo object>, plot_graph=<typer.models.OptionInfo object>, save_graph=<typer.models.OptionInfo object>)#

Analyze the output files of VMC optimizations.

Parameters:
  • filenames (List[str])

  • plot_graph (bool)

  • save_graph (str)

jqmc.jqmc_tool.vmc_chk_fix(restart_chk=<typer.models.ArgumentInfo object>)#

VMCopt chk file fix.

Parameters:

restart_chk (str)

jqmc.jqmc_tool.vmc_generate_input(flag=<typer.models.OptionInfo object>, filename=<typer.models.OptionInfo object>, exclude_comment=<typer.models.OptionInfo object>)#

Generate an input file for VMCopt calculations.

Parameters:
  • flag (bool)

  • filename (str)

  • exclude_comment (bool)

jqmc.jqmc_utility module#

Utility module.

jqmc.molecular_orbital module#

Molecular Orbital module.

class jqmc.molecular_orbital.MOs_data(num_mo=0, aos_data=<factory>, mo_coefficients=<factory>)#

Bases: object

Molecular orbital (MO) coefficients and metadata.

Holds the contraction matrix that maps atomic orbitals (AOs) to molecular orbitals (MOs). MO values are obtained as mo_coefficients @ AO_values in float64 (jax_enable_x64=True).

Variables:
  • num_mo (int) – Number of molecular orbitals.

  • aos_data (AOs_sphe_data | AOs_cart_data) – AO definition supplying centers, exponents/coefficients, angular data, and contraction mapping.

  • mo_coefficients (npt.NDArray | jax.Array) – Coefficient matrix of shape (num_mo, num_ao). Rows correspond to MOs; columns correspond to contracted AOs.

Parameters:
  • num_mo (int)

  • aos_data (AOs_sphe_data | AOs_cart_data)

  • mo_coefficients (ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex)

Examples

Minimal runnable setup (2 AOs -> 1 MO):

import numpy as np
from jqmc.structure import Structure_data
from jqmc.atomic_orbital import AOs_sphe_data
from jqmc.molecular_orbital import MOs_data

structure = Structure_data(
    positions=[[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]],
    pbc_flag=False,
    atomic_numbers=[1, 1],
    element_symbols=["H", "H"],
    atomic_labels=["H1", "H2"],
)

aos = AOs_sphe_data(
    structure_data=structure,
    nucleus_index=[0, 1],
    num_ao=2,
    num_ao_prim=2,
    orbital_indices=[0, 1],
    exponents=[1.0, 1.2],
    coefficients=[1.0, 0.8],
    angular_momentums=[0, 0],
    magnetic_quantum_numbers=[0, 0],
)
aos.sanity_check()

mo_coeffs = np.array([[0.7, 0.7]], dtype=float)  # shape (1, 2)
mos = MOs_data(num_mo=1, aos_data=aos, mo_coefficients=mo_coeffs)
mos.sanity_check()

aos_data: AOs_sphe_data | AOs_cart_data#

AO definition supplying centers, exponents/coefficients, angular data, and contraction mapping.

mo_coefficients: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex#

MO coefficient matrix, shape (num_mo, num_ao).

num_mo: int = 0#

Number of molecular orbitals.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Validate internal consistency.

Ensures mo_coefficients matches (num_mo, aos_data.num_ao), verifies num_mo is an int, and delegates AO validation to aos_data.sanity_check().

Raises:

ValueError – If coefficient shape or num_mo type is invalid, or if aos_data fails its check.

Return type:

None

property structure_data#

Return structure_data of the aos_data instance.

jqmc.molecular_orbital.compute_MOs(mos_data, r_carts)#

Evaluate molecular orbitals at electron coordinates.

Parameters:
  • mos_data (MOs_data) – MO/AO definition and coefficient matrix.

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64 (same convention as AO evaluators).

Returns:

MO values with shape (num_mo, N_e).

Return type:

jax.Array

jqmc.molecular_orbital.compute_MOs_grad(mos_data, r_carts)#

Compute MO gradients (x, y, z components) at electron coordinates.

Parameters:
  • mos_data (MOs_data) – MO/AO definition and coefficient matrix.

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64.

Returns:

Gradients per component (grad_x, grad_y, grad_z), each of shape (num_mo, N_e).

Return type:

tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]

jqmc.molecular_orbital.compute_MOs_laplacian(mos_data, r_carts)#

Compute MO Laplacians at electron coordinates.

Parameters:
  • mos_data (MOs_data) – MO/AO definition and coefficient matrix.

  • r_carts (jax.Array) – Electron Cartesian coordinates, shape (N_e, 3) in float64.

Returns:

Laplacians of each MO, shape (num_mo, N_e).

Return type:

jax.Array
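Because MO values are mo_coefficients @ AO_values, the MO gradients and Laplacians follow from the AO ones by the same linear contraction. The NumPy sketch below shows only this contraction and the shape bookkeeping; mos_from_aos is hypothetical and takes precomputed AO arrays of shape (num_ao, N_e) instead of calling jqmc's AO evaluators.

```python
import numpy as np


def mos_from_aos(mo_coefficients, ao_values, ao_grads, ao_laplacians):
    """Contract AO quantities into MO quantities; every step is linear in the AOs.

    mo_coefficients: (num_mo, num_ao); ao_values, ao_laplacians: (num_ao, N_e);
    ao_grads: (grad_x, grad_y, grad_z), each (num_ao, N_e).
    """
    C = np.asarray(mo_coefficients, dtype=np.float64)
    mo_values = C @ ao_values                    # (num_mo, N_e)
    mo_grads = tuple(C @ g for g in ao_grads)    # each (num_mo, N_e)
    mo_laplacians = C @ ao_laplacians            # (num_mo, N_e)
    return mo_values, mo_grads, mo_laplacians
```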

jqmc.setting module#

setting.

jqmc.structure module#

Structure module.

class jqmc.structure.Structure_data(positions=<factory>, pbc_flag=False, vec_a=<factory>, vec_b=<factory>, vec_c=<factory>, atomic_numbers=<factory>, element_symbols=<factory>, atomic_labels=<factory>)#

Bases: object

Atomic structure and cell metadata.

Stores Cartesian coordinates (Bohr), optional lattice vectors for PBC, and basic element metadata used by AOs/MOs, Coulomb, and Hamiltonian builders.

Variables:
  • positions (npt.NDArray | jax.Array) – Atomic Cartesian coordinates with shape (N, 3) in Bohr.

  • pbc_flag (bool) – Whether periodic boundary conditions are active. If True, lattice vectors vec_a|b|c must be provided; otherwise they must be empty.

  • vec_a (list[float] | tuple[float]) – Lattice vector a (Bohr) when pbc_flag=True.

  • vec_b (list[float] | tuple[float]) – Lattice vector b (Bohr) when pbc_flag=True.

  • vec_c (list[float] | tuple[float]) – Lattice vector c (Bohr) when pbc_flag=True.

  • atomic_numbers (list[int] | tuple[int]) – Atomic numbers Z for each site (len N).

  • element_symbols (list[str] | tuple[str]) – Element symbols for each site (len N).

  • atomic_labels (list[str] | tuple[str]) – Human-readable labels for each site (len N).

Parameters:
  • positions (ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex)

  • pbc_flag (bool)

  • vec_a (list[float] | tuple[float])

  • vec_b (list[float] | tuple[float])

  • vec_c (list[float] | tuple[float])

  • atomic_numbers (list[int] | tuple[int])

  • element_symbols (list[str] | tuple[str])

  • atomic_labels (list[str] | tuple[str])

Examples

Minimal H2 setup (Bohr):

import numpy as np
from jqmc.structure import Structure_data

structure = Structure_data(
    positions=np.array([[0.0, 0.0, -0.70], [0.0, 0.0, 0.70]]),
    pbc_flag=False,
    atomic_numbers=[1, 1],
    element_symbols=["H", "H"],
    atomic_labels=["H1", "H2"],
)
structure.sanity_check()

atomic_labels: list[str] | tuple[str]#

Human-readable labels per site (len N).

atomic_numbers: list[int] | tuple[int]#

Atomic numbers Z per site (len N).

property cell: ndarray[tuple[Any, ...], dtype[float64]]#

Lattice vectors as a (3, 3) matrix in Bohr ([a, b, c]).

element_symbols: list[str] | tuple[str]#

Element symbols per site (len N).

property lattice_vec_a: tuple#

Return lattice vector A (in Bohr).

Returns:

the lattice vector A (in Bohr).

Return type:

tuple[np.float64]

property lattice_vec_b: tuple#

Return lattice vector B (in Bohr).

Returns:

the lattice vector B (in Bohr).

Return type:

tuple[np.float64]

property lattice_vec_c: tuple#

Return lattice vector C (in Bohr).

Returns:

the lattice vector C (in Bohr).

Return type:

tuple[np.float64]

property natom: int#

The number of atoms in the system.

Returns:

The number of atoms in the system.

Return type:

int

property norm_vec_a: float#

Return the norm of the lattice vector A (in Bohr).

Returns:

the norm of the lattice vector A (in Bohr).

Return type:

np.float64

property norm_vec_b: float#

Return the norm of the lattice vector B (in Bohr).

Returns:

the norm of the lattice vector B (in Bohr).

Return type:

np.float64

property norm_vec_c: float#

Return the norm of the lattice vector C (in Bohr).

Returns:

the norm of the lattice vector C (in Bohr).

Return type:

np.float64

property ntyp: int#

The number of element types in the system.

Returns:

The number of element types in the system.

Return type:

int

pbc_flag: bool = False#

Whether periodic boundary conditions are active.

positions: ndarray[tuple[Any, ...], dtype[_ScalarT]] | Array | ndarray | bool | number | bool | int | float | complex#

Atomic Cartesian coordinates with shape (N, 3) in Bohr.

property recip_cell: ndarray[tuple[Any, ...], dtype[float64]]#

Reciprocal lattice vectors (3, 3) in Bohr^{-1}.

Uses the standard definition

\[G_a = 2\pi \frac{T_b \times T_c}{T_a \cdot (T_b \times T_c)}, \quad G_b = 2\pi \frac{T_c \times T_a}{T_b \cdot (T_c \times T_a)}, \quad G_c = 2\pi \frac{T_a \times T_b}{T_c \cdot (T_a \times T_b)},\]

and asserts the duality condition \(T_i \cdot G_j = 2\pi\,\delta_{ij}\).
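The same formula can be written as a short NumPy function, independent of Structure_data. reciprocal_cell is hypothetical; it expects the lattice vectors as the rows of a (3, 3) array, as in the [a, b, c] layout described for cell, and returns the reciprocal vectors as rows.

```python
import numpy as np


def reciprocal_cell(cell):
    """Rows of `cell` are T_a, T_b, T_c (Bohr); returns rows G_a, G_b, G_c (Bohr^-1)."""
    cell = np.asarray(cell, dtype=np.float64)
    Ta, Tb, Tc = cell
    Ga = 2.0 * np.pi * np.cross(Tb, Tc) / np.dot(Ta, np.cross(Tb, Tc))
    Gb = 2.0 * np.pi * np.cross(Tc, Ta) / np.dot(Tb, np.cross(Tc, Ta))
    Gc = 2.0 * np.pi * np.cross(Ta, Tb) / np.dot(Tc, np.cross(Ta, Tb))
    recip = np.array([Ga, Gb, Gc])
    # Duality check: T_i . G_j = 2 pi delta_ij
    assert np.allclose(cell @ recip.T, 2.0 * np.pi * np.eye(3))
    return recip
```

Equivalently, the result equals 2π times the transposed inverse of the cell matrix, which is a convenient cross-check.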

property recip_vec_a: tuple#

Return reciprocal lattice vector A (in Bohr^{-1}).

Returns:

the reciprocal lattice vector A (in Bohr^{-1}).

Return type:

tuple[np.float64]

property recip_vec_b: tuple#

Return reciprocal lattice vector B (in Bohr^{-1}).

Returns:

the reciprocal lattice vector B (in Bohr^{-1}).

Return type:

tuple[np.float64]

property recip_vec_c: tuple#

Return reciprocal lattice vector C (in Bohr^{-1}).

Returns:

the reciprocal lattice vector C (in Bohr^{-1}).

Return type:

tuple[np.float64]

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Validate consistency of positions, labels, and lattice metadata.

Ensures all per-atom arrays share length N, lattice vectors are either all present (length 3 each) when pbc_flag=True or all empty when pbc_flag=False, and basic types are as expected.

Raises:

ValueError – If list lengths mismatch, lattice vectors are inconsistent with pbc_flag, or field types are incorrect.

Return type:

None
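The rules that sanity_check enforces can be summarized in plain Python. The function check_structure and its argument layout are hypothetical. In particular, per_atom_fields stands in for whichever per-atom label arrays Structure_data actually holds. The sketch only mirrors the rules stated above.

```python
def check_structure(positions, per_atom_fields, pbc_flag, vec_a, vec_b, vec_c):
    """Hypothetical validator mirroring the rules described for sanity_check()."""
    if not isinstance(pbc_flag, bool):
        raise ValueError("pbc_flag must be a bool")
    n_atoms = len(positions)
    if any(len(p) != 3 for p in positions):
        raise ValueError("positions must have shape (N, 3)")
    # Every per-atom array must share length N.
    for name, values in per_atom_fields.items():
        if len(values) != n_atoms:
            raise ValueError(f"{name} has length {len(values)}, expected {n_atoms}")
    # Lattice vectors: all present (length 3) with PBC, all empty without.
    lengths = [len(v) for v in (vec_a, vec_b, vec_c)]
    if pbc_flag and lengths != [3, 3, 3]:
        raise ValueError("pbc_flag=True requires vec_a, vec_b, vec_c of length 3")
    if not pbc_flag and lengths != [0, 0, 0]:
        raise ValueError("pbc_flag=False requires empty lattice vectors")
```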

vec_a: list[float] | tuple[float]#

Lattice vector a in Bohr (requires pbc_flag=True).

vec_b: list[float] | tuple[float]#

Lattice vector b in Bohr (requires pbc_flag=True).

vec_c: list[float] | tuple[float]#

Lattice vector c in Bohr (requires pbc_flag=True).

jqmc.swct module#

SWCT module.

class jqmc.swct.SWCT_data(structure)#

Bases: object

Space-warp coordinate transformation metadata.

Parameters:

structure (Structure_data) – Nuclear geometry used to build SWCT weights.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the arguments.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

structure: Structure_data#

jqmc.swct.evaluate_swct_domega(swct_data, r_carts)#

Evaluate \(\sum_i \nabla_{r_i} \omega_{\alpha i}\) for each atom.

Parameters:
  • swct_data (SWCT_data) – Structure and cached geometry information.

  • r_carts (jax.Array) – Electron Cartesian coordinates with shape (N_e, 3) and float64 dtype.

Returns:

Sum of gradients per atom with shape (N_a, 3).

Return type:

jax.Array

jqmc.swct.evaluate_swct_omega(swct_data, r_carts)#

Compute SWCT weights \(\omega_{\alpha i}\) for each atom/electron pair.

Parameters:
  • swct_data (SWCT_data) – Structure and cached geometry information.

  • r_carts (jax.Array) – Electron Cartesian coordinates with shape (N_e, 3) and float64 dtype.

Returns:

Normalized weights with shape (N_a, N_e), summing to 1 over atoms for each electron.

Return type:

jax.Array
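The two SWCT functions above can be illustrated with a NumPy sketch. It assumes the kernel F(d) = d^{-4} that is common in the space-warp literature, so that \(\omega_{\alpha i} = F(|r_i - R_\alpha|) / \sum_\beta F(|r_i - R_\beta|)\). The kernel jqmc actually uses is not stated here. The helpers swct_omega and swct_domega are hypothetical and take raw nuclear positions instead of SWCT_data.

```python
import numpy as np


def swct_omega(R_carts, r_carts):
    """Weights omega[alpha, i] with shape (N_a, N_e); each column sums to 1."""
    diff = r_carts[None, :, :] - R_carts[:, None, :]  # (N_a, N_e, 3)
    d = np.linalg.norm(diff, axis=-1)                  # (N_a, N_e)
    F = d**-4                                          # assumed kernel F(d) = d^-4
    return F / F.sum(axis=0, keepdims=True)


def swct_domega(R_carts, r_carts):
    """sum_i grad_{r_i} omega[alpha, i] with shape (N_a, 3)."""
    diff = r_carts[None, :, :] - R_carts[:, None, :]
    d = np.linalg.norm(diff, axis=-1)
    F = d**-4
    grad_F = -4.0 * d[..., None] ** -6 * diff          # grad of d^-4 = -4 d^-6 (r - R)
    S = F.sum(axis=0)                                  # (N_e,)
    grad_S = grad_F.sum(axis=0)                        # (N_e, 3)
    # Quotient rule: grad(F/S) = (grad_F * S - F * grad_S) / S^2
    grad_omega = (grad_F * S[None, :, None] - F[..., None] * grad_S[None]) / S[None, :, None] ** 2
    return grad_omega.sum(axis=1)
```

Because the weights sum to 1 over atoms for every electron, the per-atom gradients returned by swct_domega sum to zero over atoms.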

jqmc.trexio_wrapper module#

TREXIO wrapper modules.

jqmc.trexio_wrapper.read_trexio_file(trexio_file, store_tuple=False)#

Load a TREXIO HDF5 file into jqmc data containers.

Parameters:
  • trexio_file (str) – Path to the TREXIO file to read (HDF5 backend expected).

  • store_tuple (bool, optional) – Store list-like fields as tuples for immutability (useful in tests with JAX/Flax), at the cost of slower production runs. Defaults to False.

Returns:

(structure_data, aos_data, mos_data_up, mos_data_dn, geminal_data, coulomb_potential_data) where:

  • structure_data is a Structure_data describing atoms and geometry.

  • aos_data is either AOs_cart_data or AOs_sphe_data depending on the basis.

  • mos_data_up and mos_data_dn are MOs_data for spin-up/down orbitals.

  • geminal_data is a Geminal_data assembled from the MO block.

  • coulomb_potential_data is a Coulomb_potential_data describing (E)CPs.

Return type:

tuple

Raises:
  • NotImplementedError – If periodic cells (PBC) or complex molecular orbitals are encountered.

  • ValueError – If atomic labels are unsupported or AO counts are inconsistent.

Notes

  • Periodic boundary conditions are parsed but not supported yet.

  • Molecular orbitals are assumed real-valued; complex coefficients are rejected.

Examples

>>> from jqmc.trexio_wrapper import read_trexio_file
>>> structure_data, aos_data, mos_up, mos_dn, geminal_data, coulomb_data = read_trexio_file("molecule.h5")
>>> structure_data.atomic_labels[:3]
['O', 'H', 'H']

jqmc.wavefunction module#

Wavefunction module.

class jqmc.wavefunction.VariationalParameterBlock(name, values, shape, size)#

Bases: object

A block of variational parameters (e.g., J1, J2, J3, lambda).

Design overview#

  • A block is the smallest unit that the optimizer (MCMC + SR) sees. Each block corresponds to a contiguous slice in the global variational parameter vector and carries enough metadata to reconstruct its original shape (name, values, shape, size).

  • This class is intentionally structure-agnostic: it does not know anything about Jastrow vs Geminal, matrix symmetry, or how a block maps to concrete fields in Jastrow_data or Geminal_data.

  • All physics- and structure-specific semantics are owned by the corresponding data classes via their get_variational_blocks and apply_block_update methods.

The goal is that adding or modifying a variational parameter only requires changes on the wavefunction side (Jastrow/Geminal data), while the MCMC/SR driver remains completely agnostic and operates purely on a list of blocks.

apply_update(delta_flat, learning_rate)#

Return a new block with values updated by a generic additive rule.

This method is intentionally structure-agnostic and only performs a simple additive update:

X_new = X_old + learning_rate * delta

Any parameter-specific constraints (e.g., symmetry of J3 or lambda_matrix) must be enforced by the owner of the parameter (jastrow_data, geminal_data, etc.) inside their apply_block_update implementations.

Parameters:
  • delta_flat (ndarray[tuple[Any, ...], dtype[_ScalarT]]) – Flattened update vector with length equal to size.

  • learning_rate (float) – Scaling factor for the update.

Return type:

VariationalParameterBlock
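The additive rule and the unflattening via shape and size can be sketched with a frozen dataclass. ParamBlock is a hypothetical stand-in and not the jqmc class. It uses dataclasses.replace in the role of the replace(**updates) method documented here.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ParamBlock:
    """Hypothetical stand-in for VariationalParameterBlock (not the jqmc class)."""

    name: str
    values: np.ndarray
    shape: tuple
    size: int

    def apply_update(self, delta_flat, learning_rate):
        # Generic additive rule: X_new = X_old + learning_rate * delta.
        delta_flat = np.asarray(delta_flat, dtype=np.float64)
        if delta_flat.size != self.size:
            raise ValueError(f"expected {self.size} entries, got {delta_flat.size}")
        delta = delta_flat.reshape(self.shape)  # unflatten to the original shape
        return replace(self, values=np.asarray(self.values) + learning_rate * delta)
```

The sketch enforces no constraints. A symmetric lambda_matrix, for example, would still have to be symmetrized by its owner, as described above.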

name: str#

Identifier for this block (for example "j1_param" or "lambda_matrix").

replace(**updates)#

Returns a new object replacing the specified fields with new values.

shape: tuple[int, ...]#

Original shape of values for unflattening updates.

size: int#

Flattened size of values used when slicing the global vector.

values: Array | ndarray | bool | number | bool | int | float | complex#

Parameter payload (keeps PyTree structure if present).

Parameters:
  • name (str)

  • values (Array | ndarray | bool | number | bool | int | float | complex)

  • shape (tuple[int, ...])

  • size (int)

class jqmc.wavefunction.Wavefunction_data(jastrow_data=<factory>, geminal_data=<factory>)#

Bases: object

Container for Jastrow and Geminal parts used to evaluate a wavefunction.

The class owns only the data needed to construct the wavefunction. All computations are delegated to the functions in this module and the underlying Jastrow/Geminal helpers.

Parameters:
  • jastrow_data (Jastrow_data) – Optional Jastrow parameters. If None, the Jastrow part is omitted.

  • geminal_data (Geminal_data) – Optional Geminal parameters. If None, the determinant part is omitted.

accumulate_position_grad(grad_wavefunction)#

Aggregate position gradients from geminal and Jastrow parts.

Parameters:

grad_wavefunction (Wavefunction_data)

apply_block_updates(blocks, thetas, learning_rate)#

Return a new Wavefunction_data with variational blocks updated.

Design notes#

  • blocks defines the ordering and shapes of all variational parameters; thetas is a single flattened update vector in the same order.

  • This method is responsible for slicing thetas into per-block pieces and performing a generic additive update via VariationalParameterBlock.apply_update().

  • The interpretation of each block (“this is J1”, “this is the J3 matrix”, “this is lambda”) and any structural constraints (symmetry, rectangular layout, etc.) are delegated to Jastrow_data.apply_block_update() and Geminal_data.apply_block_update().

Because of this separation of concerns, the MCMC/SR driver only needs to work with the flattened thetas vector and the list of blocks; it never touches Jastrow/Geminal internals directly. To add a new parameter to the optimization, one only needs to (1) expose it in get_variational_blocks(), and (2) handle it in the corresponding apply_block_update method.

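Parameters:
  • blocks – Ordered variational parameter blocks defining the layout of thetas.

  • thetas – Single flattened update vector, in the same order as blocks.

  • learning_rate – Scaling factor applied to each block update.

Return type: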

Wavefunction_data
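The slicing step can be illustrated with a small NumPy sketch. It assumes that blocks is reduced to an ordered list of (name, values) pairs. The function apply_flat_update is hypothetical, and the real method then delegates to the Jastrow/Geminal apply_block_update implementations instead of returning a dict.

```python
import numpy as np


def apply_flat_update(blocks, thetas, learning_rate):
    """Hypothetical sketch: `blocks` is an ordered list of (name, values) pairs."""
    thetas = np.asarray(thetas, dtype=np.float64)
    total = sum(np.size(values) for _, values in blocks)
    if thetas.size != total:
        raise ValueError(f"thetas has {thetas.size} entries, blocks need {total}")
    updated, offset = {}, 0
    for name, values in blocks:
        values = np.asarray(values, dtype=np.float64)
        piece = thetas[offset : offset + values.size]  # this block's contiguous slice
        updated[name] = values + learning_rate * piece.reshape(values.shape)
        offset += values.size
    return updated
```

For example, with a one-element j1_param block followed by a 2 x 2 lambda_matrix block, the first entry of thetas updates J1 and the next four fill the matrix in row-major order.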

collect_param_grads(grad_wavefunction)#

Collect parameter gradients from Jastrow and Geminal into a flat dict.

Parameters:

grad_wavefunction (Wavefunction_data)

Return type:

dict[str, object]

flatten_param_grads(param_grads, num_walkers)#

Return parameter gradients as numpy arrays ready for storage.

The caller does not need to know the internal block structure (e.g., NN trees); any necessary flattening is handled here.

Parameters:
  • param_grads (dict[str, object])

  • num_walkers (int)

Return type:

dict[str, ndarray]

geminal_data: Geminal_data#

Variational Geminal/determinant parameters.

get_variational_blocks(opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=False)#

Collect variational parameter blocks from Jastrow and Geminal parts.

Each block corresponds to a contiguous group of variational parameters (e.g., J1, J2, J3 matrix, NN Jastrow, lambda matrix). This method only exposes the parameter arrays; the corresponding gradients are handled on the MCMC side.

Parameters:
  • opt_J1_param (bool)

  • opt_J2_param (bool)

  • opt_J3_param (bool)

  • opt_JNN_param (bool)

  • opt_lambda_param (bool)

Return type:

list[VariationalParameterBlock]

jastrow_data: Jastrow_data#

Variational Jastrow parameters.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

sanity_check()#

Check attributes of the class.

This function checks the consistency among the arguments.

Raises:

ValueError – If there is an inconsistency in a dimension of a given argument.

Return type:

None

with_diff_mask(*, params=True, coords=True)#

Return a copy with gradients masked according to the provided flags.

Parameters:
  • params (bool)

  • coords (bool)

Return type:

Wavefunction_data

with_param_grad_mask(*, opt_J1_param=True, opt_J2_param=True, opt_J3_param=True, opt_JNN_param=True, opt_lambda_param=True)#

Return a copy where disabled parameter blocks stop propagating gradients.

Developer note#

  • The per-block flags (opt_J1_param etc.) decide which high-level blocks are masked. Disabled blocks are wrapped with DiffMask(params=False, coords=True), meaning parameter gradients are stopped while coordinate gradients still flow.

  • Within each disabled block, apply_diff_mask uses field-name heuristics (see diff_mask._PARAM_FIELD_NAMES) to tag parameter leaves such as lambda_matrix, j_matrix, jastrow_1b_param, jastrow_2b_param, jastrow_3b_param, and params. Those tagged leaves receive jax.lax.stop_gradient, so their backpropagated gradients become zero.

  • Example: if opt_J1_param=False and the other flags are True, only the J1 block is masked; its parameter leaves are stopped, while J2/J3/NN/lambda continue to propagate gradients normally.

Parameters:
  • opt_J1_param (bool)

  • opt_J2_param (bool)

  • opt_J3_param (bool)

  • opt_JNN_param (bool)

  • opt_lambda_param (bool)

Return type:

Wavefunction_data
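The effect of jax.lax.stop_gradient on a disabled block can be shown with a toy model. In the JAX sketch below, the log-wavefunction ln_psi, the parameter names j1/j2, and the helper masked are all hypothetical. It does not use jqmc's DiffMask or apply_diff_mask. It only shows that stopped leaves get zero parameter gradients while coordinate gradients still flow.

```python
import jax
import jax.numpy as jnp


def ln_psi(params, r):
    # Toy log-wavefunction with two hypothetical parameter blocks.
    return -params["j1"] * jnp.sum(r**2) - params["j2"] * jnp.sum(jnp.abs(r))


def masked(params, opt_j1=True, opt_j2=True):
    # Disabled blocks are wrapped in stop_gradient: forward values are
    # unchanged, but no gradient flows back into those leaves.
    def mask(x, enabled):
        return x if enabled else jax.lax.stop_gradient(x)

    return {"j1": mask(params["j1"], opt_j1), "j2": mask(params["j2"], opt_j2)}


params = {"j1": jnp.array(0.5), "j2": jnp.array(0.2)}
r = jnp.array([0.3, -0.4, 1.0])

# Parameter gradients with the J1-like block disabled.
param_grads = jax.grad(lambda p: ln_psi(masked(p, opt_j1=False), r))(params)

# Coordinate gradients still flow through the masked parameters.
coord_grads = jax.grad(lambda x: ln_psi(masked(params, opt_j1=False), x))(r)
```

Here param_grads["j1"] is exactly zero, while param_grads["j2"] equals -sum|r|.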

jqmc.wavefunction.compute_discretized_kinetic_energy(alat, wavefunction_data, r_up_carts, r_dn_carts, RT)#

Compute discretized kinetic mesh points and energies for a given lattice spacing alat.

This keeps the original semantics used by the LRDMC path: wavefunction ratios are computed as exp(J_xp - J_x) * det_xp / det_x, where x is the current configuration and xp a mesh point. Inputs are coerced to float64 jax.Array before evaluation.

Parameters:
  • alat (float) – Lattice spacing of the Hamiltonian discretization (Bohr). Slated to be replaced with LRDMC_data.

  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Up-electron positions with shape (n_up, 3).

  • r_dn_carts (Array) – Down-electron positions with shape (n_dn, 3).

  • RT (Array) – Rotation matrix (\(R^T\)) with shape (3, 3).

Returns:

A tuple (r_up_carts_combined, r_dn_carts_combined, elements_kinetic_part) where the combined coordinate arrays have shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3) and elements_kinetic_part contains the kinetic prefactor-scaled ratios.

Return type:

tuple[list[tuple[ndarray[tuple[Any, …], dtype[_ScalarT]], ndarray[tuple[Any, …], dtype[_ScalarT]]]], list[ndarray[tuple[Any, …], dtype[_ScalarT]]], Array]
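The mesh construction can be sketched in NumPy under the simplest LRDMC-style discretization. In this scheme one electron at a time moves by ±alat along three directions, and each mesh point gets the off-diagonal element -1/(2 alat²) · Ψ(x')/Ψ(x). Several details are assumptions: the rows of RT serve as the displacement directions, psi is any callable returning the wavefunction value, and discretized_kinetic_mesh is a hypothetical name. jqmc's exact mesh and prefactors may differ.

```python
import numpy as np


def discretized_kinetic_mesh(alat, psi, r_up, r_dn, RT):
    """Hypothetical LRDMC-style mesh: move one electron at a time by +/- alat.

    Assumes the three rows of RT are the (rotated) displacement directions,
    so n_grid = 6 * (n_up + n_dn).
    """
    r_up, r_dn, RT = (np.asarray(a, dtype=np.float64) for a in (r_up, r_dn, RT))
    moves = alat * np.concatenate([RT, -RT])  # (6, 3) displacement vectors
    psi_x = psi(r_up, r_dn)
    ups, dns, elements = [], [], []
    for is_up, n in ((True, len(r_up)), (False, len(r_dn))):
        for i in range(n):
            for d in moves:
                new_up, new_dn = r_up.copy(), r_dn.copy()
                if is_up:
                    new_up[i] += d
                else:
                    new_dn[i] += d
                ups.append(new_up)
                dns.append(new_dn)
                # Off-diagonal kinetic element: -1/(2 a^2) * Psi(x') / Psi(x)
                elements.append(-0.5 / alat**2 * psi(new_up, new_dn) / psi_x)
    return np.array(ups), np.array(dns), np.array(elements)
```

The returned coordinate arrays have shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3), matching the layout described above.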

jqmc.wavefunction.compute_discretized_kinetic_energy_fast_update(alat, wavefunction_data, A_old_inv, r_up_carts, r_dn_carts, RT)#

Fast-update version of discretized kinetic mesh and ratios.

Computes the discretized kinetic grid points and their energies for a given lattice spacing (alat), using the precomputed A_old_inv to evaluate determinant ratios efficiently. Inputs are converted to float64 jax.Array before use.

Parameters:
  • alat (float) – Lattice spacing of the Hamiltonian discretization (Bohr). Slated to be replaced with LRDMC_data.

  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • A_old_inv (Array) – Inverse of the geminal matrix evaluated at (r_up_carts, r_dn_carts).

  • r_up_carts (Array) – Up-electron positions with shape (n_up, 3).

  • r_dn_carts (Array) – Down-electron positions with shape (n_dn, 3).

  • RT (Array) – Rotation matrix (\(R^T\)) with shape (3, 3).

Returns:

Tuple (r_up_carts_combined, r_dn_carts_combined, elements_kinetic_part) with combined coordinate arrays of shapes (n_grid, n_up, 3) and (n_grid, n_dn, 3), and kinetic prefactor-scaled ratios elements_kinetic_part.

Return type:

tuple[Array, Array, Array]

jqmc.wavefunction.compute_kinetic_energy(wavefunction_data, r_up_carts, r_dn_carts)#

Compute kinetic energy using analytic gradients and Laplacians.

Computes the kinetic energy of the wavefunction at (r_up_carts, r_dn_carts), relying entirely on JAX for the calculation. Inputs are converted to float64 jax.Array for consistency with other compute utilities.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Kinetic energy evaluated for the supplied configuration.

Return type:

float | complex
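The identity behind the kinetic energy is \(T_L = -\tfrac{1}{2}\,\nabla^2\Psi/\Psi = -\tfrac{1}{2}\left(\nabla^2 \ln\Psi + |\nabla \ln\Psi|^2\right)\). The JAX sketch below evaluates it for a toy, non-antisymmetrized product of hydrogen-like 1s orbitals exp(-|r_i|), for which \(T_L = \sum_i (1/|r_i| - 1/2)\). The names ln_psi and local_kinetic_energy are hypothetical. The sketch uses autodiff, whereas compute_kinetic_energy relies on analytic gradients and Laplacians.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ln_psi(r):
    # Hypothetical toy wavefunction: Psi = prod_i exp(-|r_i|), r has shape (n_e, 3).
    return -jnp.sum(jnp.linalg.norm(r, axis=-1))


def local_kinetic_energy(ln_psi, r):
    # T_L = -1/2 * (laplacian(ln Psi) + |grad ln Psi|^2), over all 3 * n_e coordinates.
    shape = r.shape
    flat = lambda x: ln_psi(x.reshape(shape))
    x = r.reshape(-1)
    grad = jax.grad(flat)(x)
    lap = jnp.trace(jax.hessian(flat)(x))
    return -0.5 * (lap + grad @ grad)
```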

jqmc.wavefunction.compute_kinetic_energy_all_elements(wavefunction_data, r_up_carts, r_dn_carts)#

Analytic-derivative kinetic energy per electron (matches auto output shape).

Returns the per-electron kinetic energy using analytic gradients/Laplacians of both Jastrow and determinant parts. Shapes align with _compute_kinetic_energy_all_elements_auto.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Tuple of two jax.Array objects containing per-electron kinetic energies for spin-up and spin-down electrons, respectively.

Return type:

Array

jqmc.wavefunction.compute_kinetic_energy_all_elements_fast_update(wavefunction_data, r_up_carts, r_dn_carts, geminal_inverse)#

Kinetic energy per electron using a precomputed geminal inverse.

Parameters:
  • wavefunction_data (Wavefunction_data)

  • r_up_carts (Array)

  • r_dn_carts (Array)

  • geminal_inverse (Array)

Return type:

Array

jqmc.wavefunction.compute_quantum_force(wavefunction_data, r_up_carts, r_dn_carts)#

Compute quantum forces 2 * grad ln |Psi| at the given coordinates.

Computes the quantum force at (r_up_carts, r_dn_carts). Gradients from the Jastrow part are currently set to zero (as in the original implementation); determinant gradients are included via compute_grads_and_laplacian_ln_Det. Inputs are coerced to float64 jax.Array for consistency.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Tuple (force_up, force_dn) with shapes matching the input coordinate arrays.

Return type:

tuple[Array, Array]
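The definition F = 2 ∇ ln|Ψ| can be illustrated by differentiating a log-determinant with JAX. The JAX sketch below covers only a determinant part, as the description above does. It uses a toy up-spin matrix A[i, k] = exp(-a_k |r_i|) with hypothetical exponents and helper names, not jqmc's geminal.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

EXPONENTS = jnp.array([1.0, 2.0])  # hypothetical orbital exponents


def ln_abs_det(r_up):
    # Toy determinant part: A[i, k] = exp(-a_k * |r_i|) for up electrons only.
    dist = jnp.linalg.norm(r_up, axis=-1)
    A = jnp.exp(-EXPONENTS[None, :] * dist[:, None])
    _, logabsdet = jnp.linalg.slogdet(A)
    return logabsdet


def quantum_force(r_up):
    # F = 2 * grad ln|Psi|, same shape as the input coordinates.
    return 2.0 * jax.grad(ln_abs_det)(r_up)
```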

jqmc.wavefunction.evaluate_determinant(wavefunction_data, r_up_carts, r_dn_carts)#

Evaluate the determinant (Geminal) part of the wavefunction.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Determinant value evaluated at the supplied coordinates.

Return type:

float

jqmc.wavefunction.evaluate_jastrow(wavefunction_data, r_up_carts, r_dn_carts)#

Evaluate the Jastrow factor \(\exp(J)\) at the given coordinates.

Evaluates the Jastrow part of the wavefunction (Psi) at (r_up_carts, r_dn_carts). The returned value already includes the exponential, i.e., exp(J).

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Real Jastrow factor exp(J).

Return type:

float

jqmc.wavefunction.evaluate_ln_wavefunction(wavefunction_data, r_up_carts, r_dn_carts)#

Evaluate the logarithm of |wavefunction| (\(\ln |\Psi|\)).

This follows the original behavior: the Jastrow exponent J is kept in full and added to log(abs(det)) of the determinant part, i.e., \(\ln|\Psi| = J + \ln|\mathrm{Det}|\). The inputs are converted to float64 jax.Array for downstream consistency.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Scalar log-value of the wavefunction magnitude.

Return type:

float
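Working in the log domain avoids overflow when the determinant is very large or very small. The NumPy sketch below assumes that the determinant part is the determinant of a single square matrix and that J is real. The helper ln_abs_psi is hypothetical, and this is not necessarily how jqmc evaluates the determinant.

```python
import numpy as np


def ln_abs_psi(J, geminal_matrix):
    """ln|Psi| = J + ln|det G|; slogdet avoids forming det G explicitly."""
    _, logabsdet = np.linalg.slogdet(np.asarray(geminal_matrix, dtype=np.float64))
    return J + logabsdet
```

For G = 1e200 * I (2 x 2), det G = 1e400 exceeds the float64 range. The log-domain value 400 ln 10 is still finite.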

jqmc.wavefunction.evaluate_wavefunction(wavefunction_data, r_up_carts, r_dn_carts)#

Evaluate the wavefunction Psi at given electron coordinates.

Evaluates the wavefunction (Psi) at (r_up_carts, r_dn_carts) and returns exp(J) * Det, the product of the Jastrow factor and the determinant part. Inputs are coerced to float64 jax.Array to match other compute utilities.

Parameters:
  • wavefunction_data (Wavefunction_data) – Wavefunction parameters (Jastrow + Geminal).

  • r_up_carts (Array) – Cartesian coordinates of up-spin electrons with shape (n_up, 3).

  • r_dn_carts (Array) – Cartesian coordinates of down-spin electrons with shape (n_dn, 3).

Returns:

Complex or real wavefunction value.

Return type:

float | complex

Module contents#