Skip to content
Open

GPU #85

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions TB2J/GreenGPU.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
GreenGPU: GPU-accelerated Green's function computation using JAX.
"""

import copy
import time

import numpy as np
from HamiltonIO.model.occupations import GaussOccupations

from TB2J.green import TBGreen, find_energy_ingap


class TBGreenGPU(TBGreen):
"""
GPU-accelerated version of TBGreen using JAX.

This class inherits from TBGreen and overrides the eigenvalue/eigenvector
preparation to use GPU acceleration via JAX.
"""

def __init__(
self,
tbmodel,
kmesh=None,
ibz=False,
efermi=None,
gamma=False,
kpts=None,
kweights=None,
k_sym=False,
use_cache=False,
cache_path=None,
nproc=1,
initial_emin=-25,
smearing_width=0.01,
):
"""
Initialize TBGreenGPU.

Note: use_gpu parameter is not needed here as this class is specifically
for GPU computation. If GPU is not available, an error will be raised.
"""
# Initialize base attributes (skip TBGreen.__init__ to avoid calling _prepare_eigen)
self.initial_emin = initial_emin
self.tbmodel = tbmodel
self.is_orthogonal = tbmodel.is_orthogonal
self.R2kfactor = tbmodel.R2kfactor
self.k2Rfactor = -tbmodel.R2kfactor
self.efermi = efermi
self._use_cache = use_cache
self.cache_path = cache_path
self.use_gpu = True

if use_cache:
self._prepare_cache()

self.prepare_kpts(
kmesh=kmesh,
ibz=ibz,
gamma=gamma,
kpts=kpts,
kweights=kweights,
tbmodel=tbmodel,
)

self.norb = tbmodel.norb
self.nbasis = tbmodel.nbasis
self.k_sym = k_sym
self.nproc = nproc
self.fermi_width = float(smearing_width)

# Initialize Rmap for spin-phonon coupling
self._Rmap = None
self._Rmap_rev = None

print(
f"starting to prepare eigenvalues and eigenvectors for {self.nkpts} k-points..."
)
t0 = time.time()

# Call GPU eigen preparation
self._prepare_eigen()

print(
f"Finished preparing eigenvalues and eigenvectors. Time taken: {time.time() - t0:.2f} seconds"
)

def _prepare_eigen(self, solve=True, saveH=False):
"""
Calculate eigenvalues and eigenvectors for all k-points using GPU acceleration.

Uses JAX for GPU-accelerated computation of H(k), S(k) and their
eigenvalue decomposition.
"""
import jax.numpy as jnp

from TB2J.gpu.jax_utils import (
_compute_Hk_Sk_all_jax,
_prepare_eigen_gpu,
_prepare_HR_jax,
)

nkpts = len(self.kpts)
self.nkpts = nkpts
self.evals = np.zeros((nkpts, self.nbasis), dtype=float)
self.H0 = np.zeros((self.nbasis, self.nbasis), dtype=complex)
self.evecs = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
self.H = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
if not self.is_orthogonal:
self.S = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
else:
self.S = None

print("Using GPU for eigenvalue/eigenvector computation...")

# Compute eigenvalues and eigenvectors on GPU
self.evals, self.evecs, Sk_all = _prepare_eigen_gpu(self.tbmodel, self.kpts)

# H is not computed on GPU, set to None
self.H = None

# Store S(k) if non-orthogonal
if Sk_all is not None:
self.S = Sk_all
else:
self.S = None

# Compute H0 by averaging H(k) over all k-points on GPU
Rpts_jax, HR_jax, SR_jax, R2kfactor = _prepare_HR_jax(self.tbmodel)
kpts_jax = jnp.array(self.kpts)
Hk_all, _ = _compute_Hk_Sk_all_jax(
Rpts_jax, HR_jax, SR_jax, kpts_jax, R2kfactor
)
self.H0 = np.array(jnp.mean(Hk_all, axis=0))

# Get Fermi energy if not provided
if self.efermi is None:
print("Calculating Fermi energy from eigenvalues")
print(f"Number of electrons: {self.tbmodel.nel} ")

occ = GaussOccupations(
nel=self.tbmodel.nel, width=0.1, wk=self.kweights, nspin=2
)
self.efermi = occ.efermi(copy.deepcopy(self.evals))
print(f"Fermi energy found: {self.efermi}")

# Adjust energy range
self.adjusted_emin = (
find_energy_ingap(
self.evals, rbound=self.efermi + self.initial_emin, gap=2.0
)
- self.efermi
)

self.evals, self.evecs = self._reduce_eigens(
self.evals,
self.evecs,
emin=self.efermi + self.adjusted_emin,
emax=self.efermi + 5.1,
)

# Handle caching if enabled
if self._use_cache:
evecs = self.evecs
self.evecs_shape = self.evecs.shape
self.evecs = np.memmap(
self._get_cache_path("evecs.dat"),
mode="w+",
shape=self.evecs.shape,
dtype=complex,
)
self.evecs[:, :, :] = evecs[:, :, :]
del self.evecs

if not self.is_orthogonal:
S = self.S
self.S = np.memmap(
self._get_cache_path("S.dat"),
mode="w+",
shape=(nkpts, self.nbasis, self.nbasis),
dtype=complex,
)
self.S[:] = S[:]
del self.S

def _get_cache_path(self, filename):
"""Get full path for cache file."""
import os

return os.path.join(self.cache_path, filename)
18 changes: 10 additions & 8 deletions TB2J/sisl_wrapper.py → TB2J/deprecated/sisl_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from collections import defaultdict

import numpy as np
from ase.atoms import Atoms
from TB2J.utils import symbol_number
from collections import defaultdict
from scipy.linalg import eigh
from TB2J.myTB import AbstractTB

from TB2J.mathutils import Lowdin
from TB2J.myTB import AbstractTB
from TB2J.utils import symbol_number


class SislWrapper(AbstractTB):
Expand Down Expand Up @@ -42,7 +44,7 @@ def __init__(self, sisl_hamiltonian, geom=None, spin=None):
symnum = sdict[ia]
try:
orb_names = [f"{symnum}|{x.name()}|up" for x in a.orbital]
except:
except Exception:
orb_names = [f"{symnum}|{x.name()}|up" for x in a.orbitals]
self.orbs += orb_names
self.orb_dict[ia] += orb_names
Expand Down Expand Up @@ -170,12 +172,12 @@ def HSE_k(self, k, convention=2):

def HS_and_eigen(self, kpts, convention=2):
nkpts = len(kpts)
evals = np.zeros((nkpts, self.nbasis), dtype=float)
evals = np.zeros((nkpts, self.nbasis), dtype=float) # noqa: F841
self.nkpts = nkpts
if not self._use_cache:
evecs = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
H = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
S = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex)
evecs = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex) # noqa: F841
H = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex) # noqa: F841
S = np.zeros((nkpts, self.nbasis, self.nbasis), dtype=complex) # noqa: F841
else:
self._prepare_cache()

Expand Down
1 change: 1 addition & 0 deletions TB2J/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def set_tbmodels(self, tbmodels):
nproc=self.nproc,
initial_emin=self.emin,
smearing_width=self.smearing,
use_gpu=self.use_gpu,
)
if self.efermi is None:
self.efermi = self.G.efermi
Expand Down
56 changes: 54 additions & 2 deletions TB2J/exchange_params.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import argparse
from dataclasses import dataclass

import ase.units
import yaml

__all__ = ["ExchangeParams", "add_exchange_args_to_parser", "parser_argument_to_dict"]


@dataclass
class ExchangeParams:
"""
A class to store the parameters for exchange calculation.
Expand Down Expand Up @@ -35,6 +33,9 @@ class ExchangeParams:
mae_angles = None
orth = False
ibz = False
use_gpu = False
vectorize_energy = False
e_batch_size = None
# Debug options
debug_options = {
"compute_charge_moments": False, # Whether to compute charge and magnetic moments with Green's function method
Expand Down Expand Up @@ -66,8 +67,31 @@ def __init__(
orth=False,
ibz=False,
index_magnetic_atoms=None,
use_gpu=None,
vectorize_energy=False,
e_batch_size=None,
debug_options=None,
**kwargs,
):
# Only check for GPU if use_gpu is explicitly True or None
# Note: setting use_gpu=None no longer auto-detects GPU to avoid
# importing JAX when not needed (which allocates GPU memory)
if use_gpu is True:
from TB2J.gpu.jax_utils import _is_gpu_available

platform = _is_gpu_available()
if platform:
print(f"Using GPU ({platform}) for acceleration.")
else:
print("GPU requested but not available, falling back to CPU.")
use_gpu = False
else:
platform = None
use_gpu = False

if not use_gpu:
print("Using CPU for calculation.")

self.efermi = efermi
self.smearing = smearing
self.basis = basis
Expand All @@ -93,6 +117,13 @@ def __init__(
self.orth = orth
self.ibz = ibz
self.index_magnetic_atoms = index_magnetic_atoms
self.use_gpu = use_gpu
self.vectorize_energy = vectorize_energy
self.e_batch_size = e_batch_size

# Save other kwargs
for key, val in kwargs.items():
setattr(self, key, val)

# Initialize debug options
if debug_options is None:
Expand Down Expand Up @@ -279,6 +310,24 @@ def add_exchange_args_to_parser(parser: argparse.ArgumentParser):
nargs="*",
default=None,
)
parser.add_argument(
"--use_gpu",
help="Whether to use GPU acceleration (requires JAX). Default is True if JAX/CUDA is found.",
action="store_true",
default=None,
)
parser.add_argument(
"--vectorize_energy",
help="Whether to vectorize over the entire energy contour on GPU. Default: False",
action="store_true",
default=False,
)
parser.add_argument(
"--e_batch_size",
help="Batch size for energy points on GPU. Default: None",
type=int,
default=None,
)

return parser

Expand Down Expand Up @@ -321,4 +370,7 @@ def parser_argument_to_dict(args) -> dict:
"output_path": args.output_path,
"orth": args.orth,
"index_magnetic_atoms": ind_mag_atoms,
"use_gpu": args.use_gpu,
"vectorize_energy": args.vectorize_energy,
"e_batch_size": args.e_batch_size,
}
7 changes: 2 additions & 5 deletions TB2J/exchange_pert2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,10 @@ def _prepare_Patom(self):
uu, dd, ud, du = H[:ni, :ni], H[ni:, ni:], H[:ni, ni:], H[ni:, :ni]
else:
uu, dd, ud, du = H[::2, ::2], H[1::2, 1::2], H[::2, 1::2], H[1::2, ::2]
# Standard Pauli components: m3 = (uu - dd)/2
# But the physics of G says Tr(uu) > Tr(dd) (Up-moment)
# So H_uu must be lower than H_dd.
# Thus P = (H_uu - H_dd)/2 should be negative.
# Standard Pauli components
m1 = (ud + du) / 2.0
m2 = (ud - du) * 0.5j
m3 = (dd - uu) / 2.0 # Inverted to match moment sign
m3 = (uu - dd) / 2.0
ex, ey, ez = np.trace(m1).real, np.trace(m2).real, np.trace(m3).real
evec = np.array((ex, ey, ez))
norm = np.linalg.norm(evec)
Expand Down
17 changes: 17 additions & 0 deletions TB2J/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from TB2J.gpu.exchange_ncl_gpu import ExchangeGPU, ExchangeNCLGPU
from TB2J.gpu.exchange_pert2_gpu import ExchangePert2GPU
from TB2J.gpu.exchangeCL_gpu import ExchangeCL2GPU, ExchangeCLGPU
from TB2J.gpu.jax_utils import _check_jax, _require_jax
from TB2J.gpu.mae_green_gpu import MAEGreenGPU, MAEGreenJAX

__all__ = [
"ExchangeGPU",
"ExchangeNCLGPU",
"ExchangePert2GPU",
"ExchangeCL2GPU",
"ExchangeCLGPU",
"_check_jax",
"_require_jax",
"MAEGreenGPU",
"MAEGreenJAX",
]
Loading