Commit 0649ba8f authored by Victor Yu
Browse files

Add distributed band parallelization to Wstat

This feature reduces the memory required per image. It can be enabled by
specifying `-nb xxx` on the command line (see the QE manual).
parent 9cdb9297
......@@ -54,7 +54,7 @@ stages:
- make conf PYT=python3 PYT_LDFLAGS="`python3-config --ldflags --embed`"
- make -j4 all
- cd test-suite
- make NP=$CI_NP NI=$CI_NI NT=$CI_NT
- make NP=$CI_NP NI=$CI_NI NB=$CI_NB NT=$CI_NT
artifacts:
when: on_failure
paths:
......@@ -83,6 +83,7 @@ gcc840_t:
variables:
CI_NP: 8
CI_NI: 1
CI_NB: 1
CI_NT: 1
extends:
- .template_bot_start
......@@ -92,7 +93,8 @@ gcc840_t:
gcc930_t:
variables:
CI_NP: 8
CI_NI: 1
CI_NI: 2
CI_NB: 2
CI_NT: 1
extends:
- .template_bot_start
......@@ -103,9 +105,10 @@ gcc930_t:
gcc840_t2:
variables:
CI_NP: 4
CI_NP: 8
CI_NI: 2
CI_NT: 1
CI_NB: 1
CI_NT: 2
only:
- schedules
extends:
......@@ -115,8 +118,9 @@ gcc840_t2:
gcc930_t2:
variables:
CI_NP: 2
CI_NI: 2
CI_NP: 8
CI_NI: 1
CI_NB: 2
CI_NT: 2
only:
- schedules
......
......@@ -9,6 +9,7 @@ MODFLAGS= $(MOD_FLAG)../../iotk/src $(MOD_FLAG)../../Modules $(MOD_FLAG)../../LA
$(MOD_FLAG)../Tools \
$(MOD_FLAG)../FFT_kernel \
$(MOD_FLAG)../Coulomb_kernel \
$(MOD_FLAG)../Para_kernel \
$(MOD_FLAG).
IFLAGS=
......
......@@ -54,7 +54,11 @@ SUBROUTINE apply_sternheimerop_to_m_wfcs(nbndval, psi, hpsi, e, alpha, m)
IF(l_kinetic_only) THEN
CALL k_psi( npwx, npw, m, psi, hpsi )
ELSE
CALL h_psi( npwx, npw, m, psi, hpsi )
!
! use h_psi_, i.e. h_psi without band parallelization, as wstat
! handles band parallelization separately in dfpt_module
!
CALL h_psi_( npwx, npw, m, psi, hpsi )
ENDIF
!
! then we compute the operator H-epsilon S
......
......@@ -23,31 +23,29 @@ MODULE dfpt_module
!-----------------------------------------------------------------------
!
USE kinds, ONLY : DP
USE constants, ONLY : tpi
USE io_global, ONLY : stdout
USE wvfct, ONLY : nbnd,g2kin,et
USE fft_base, ONLY : dfftp,dffts
USE gvect, ONLY : nl,nl,gstart,g,ngm
USE wvfct, ONLY : nbnd,et
USE fft_base, ONLY : dffts
USE gvect, ONLY : gstart
USE wavefunctions_module, ONLY : evc,psic
USE gvecw, ONLY : gcutw
USE mp, ONLY : mp_sum,mp_barrier,mp_bcast
USE mp_global, ONLY : inter_image_comm,inter_pool_comm,my_image_id
USE mp_global, ONLY : inter_image_comm,inter_pool_comm,my_image_id,inter_bgrp_comm
USE fft_at_k, ONLY : single_fwfft_k,single_invfft_k
USE fft_at_gamma, ONLY : single_fwfft_gamma,single_invfft_gamma,double_fwfft_gamma,double_invfft_gamma
USE fft_interfaces, ONLY : fwfft, invfft
USE buffers, ONLY : get_buffer
USE noncollin_module, ONLY : noncolin,npol
USE bar, ONLY : bar_type,start_bar_type,update_bar_type,stop_bar_type
USE pwcom, ONLY : current_spin,wk,nks,nelup,neldw,isk,xk,npw,npwx,lsda,nkstot,&
USE pwcom, ONLY : current_spin,nelup,neldw,isk,xk,npw,npwx,lsda,&
& current_k,ngk,igk_k
USE cell_base, ONLY : tpiba2,omega,at
USE control_flags, ONLY : gamma_only, io_level
USE io_files, ONLY : tmp_dir, nwordwfc, iunwfc, diropn
USE uspp, ONLY : nkb, vkb, okvan
USE cell_base, ONLY : omega
USE control_flags, ONLY : gamma_only
USE uspp, ONLY : nkb, vkb
USE westcom, ONLY : nbnd_occ,iuwfc,lrwfc,npwqx,npwq,igq_q,fftdriver
USE io_push, ONLY : io_push_title
USE mp_world, ONLY : mpime,world_comm
USE mp_world, ONLY : world_comm
USE types_bz_grid, ONLY : k_grid, q_grid, compute_phase
USE class_idistribute, ONLY : idistribute
USE distribution_center, ONLY : occband
!
IMPLICIT NONE
!
......@@ -61,14 +59,15 @@ MODULE dfpt_module
!
! Workspace
!
INTEGER :: ipert, ig, ir, ibnd, iks, ikqs, ikq, ik, is
INTEGER :: i, j, k
INTEGER :: ipert, ig, ir, ibnd, ibnd2, lbnd, iks, ikqs, ikq, ik, is
INTEGER :: nbndval, ierr
INTEGER :: npwkq
!
REAL(DP) :: g0(3)
REAL(DP) :: anorm
REAL(DP), ALLOCATABLE :: eprec(:)
REAL(DP), ALLOCATABLE :: eprec_loc(:)
REAL(DP), ALLOCATABLE :: et_loc(:)
!
COMPLEX(DP), ALLOCATABLE :: dvpsi(:,:),dpsi(:,:)
COMPLEX(DP), ALLOCATABLE :: aux_r(:),aux_g(:)
......@@ -79,8 +78,6 @@ MODULE dfpt_module
!
TYPE(bar_type) :: barra
!
LOGICAL :: conv_dfpt
LOGICAL :: exst,exst_mem
LOGICAL :: l_dost
!
CHARACTER(LEN=512) :: title
......@@ -98,6 +95,8 @@ MODULE dfpt_module
ENDIF
CALL io_push_title(TRIM(ADJUSTL(title)))
!
occband = idistribute()
!
dng=0.0_DP
!
CALL start_bar_type( barra, 'dfpt', MAX(m,1) * k_grid%nps )
......@@ -124,6 +123,8 @@ MODULE dfpt_module
!
nbndval = nbnd_occ(iks)
!
CALL occband%init( nbndval, 'b', 'occband', .FALSE. )
!
! ... Number of G vectors for PW expansion of wfs at k
!
npw = ngk(iks)
......@@ -163,10 +164,18 @@ MODULE dfpt_module
!
!
ALLOCATE(eprec(nbndval))
ALLOCATE(eprec_loc(occband%nloc))
ALLOCATE(et_loc(occband%nloc))
CALL set_eprec(nbndval,evc(1,1),eprec)
!
ALLOCATE(dvpsi(npwx*npol,nbndval))
ALLOCATE(dpsi(npwx*npol,nbndval))
DO lbnd = 1,occband%nloc
ibnd = occband%l2g(lbnd)
eprec_loc(lbnd) = eprec(ibnd)
et_loc(lbnd) = et(ibnd,ikqs)
ENDDO
!
ALLOCATE(dvpsi(npwx*npol,occband%nloc))
ALLOCATE(dpsi(npwx*npol,occband%nloc))
!
DO ipert = 1, m
!
......@@ -193,31 +202,38 @@ MODULE dfpt_module
IF(gamma_only) THEN
!
! double bands @ gamma
DO ibnd=1,nbndval-MOD(nbndval,2),2
DO lbnd = 1,occband%nloc-MOD(occband%nloc,2),2
!
CALL double_invfft_gamma(dffts,npw,npwx,evc(1,ibnd),evc(1,ibnd+1),psic,'Wave')
ibnd = occband%l2g(lbnd)
ibnd2 = occband%l2g(lbnd+1)
!
CALL double_invfft_gamma(dffts,npw,npwx,evc(1,ibnd),evc(1,ibnd2),psic,'Wave')
DO CONCURRENT (ir=1:dffts%nnr)
psic(ir) = psic(ir) * REAL(aux_r(ir),KIND=DP)
ENDDO
CALL double_fwfft_gamma(dffts,npw,npwx,psic,dvpsi(1,ibnd),dvpsi(1,ibnd+1),'Wave')
CALL double_fwfft_gamma(dffts,npw,npwx,psic,dvpsi(1,lbnd),dvpsi(1,lbnd+1),'Wave')
!
ENDDO
!
! single band @ gamma
IF( MOD(nbndval,2) == 1 ) THEN
ibnd=nbndval
IF( MOD(occband%nloc,2) == 1 ) THEN
!
lbnd = occband%nloc
ibnd = occband%l2g(lbnd)
!
CALL single_invfft_gamma(dffts,npw,npwx,evc(1,ibnd),psic,'Wave')
DO CONCURRENT (ir=1:dffts%nnr)
psic(ir) = CMPLX( REAL(psic(ir),KIND=DP) * REAL(aux_r(ir),KIND=DP), 0._DP, KIND=DP)
ENDDO
CALL single_fwfft_gamma(dffts,npw,npwx,psic,dvpsi(1,ibnd),'Wave')
CALL single_fwfft_gamma(dffts,npw,npwx,psic,dvpsi(1,lbnd),'Wave')
!
ENDIF
!
ELSE
!
DO ibnd = 1, nbndval
DO lbnd = 1,occband%nloc
!
ibnd = occband%l2g(lbnd)
!
! ... inverse Fourier transform of wfs at [k-q]: (k-q+)G ---> R
!
......@@ -233,14 +249,16 @@ MODULE dfpt_module
! Fourier transform product of wf at [k-q], phase and
! perturbation of wavevector q: R ---> (k+)G
!
CALL single_fwfft_k(dffts,npw,npwx,psic,dvpsi(1,ibnd),'Wave',igk_k(1,iks))
CALL single_fwfft_k(dffts,npw,npwx,psic,dvpsi(1,lbnd),'Wave',igk_k(1,iks))
!
! dv|psi> is in dvpsi
!
ENDDO
!
IF (noncolin) THEN
DO ibnd = 1, nbndval
DO lbnd = 1,occband%nloc
!
ibnd = occband%l2g(lbnd)
!
CALL single_invfft_k(dffts,npwkq,npwx,evckmq(npwx+1,ibnd),psic,'Wave',igk_k(1,ikqs))
!
......@@ -248,7 +266,7 @@ MODULE dfpt_module
psic(ir) = psic(ir) * phase(ir) * aux_r(ir)
ENDDO
!
CALL single_fwfft_k(dffts,npw,npwx,psic,dvpsi(npwx+1,ibnd),'Wave',igk_k(1,iks))
CALL single_fwfft_k(dffts,npw,npwx,psic,dvpsi(npwx+1,lbnd),'Wave',igk_k(1,iks))
!
ENDDO
ENDIF
......@@ -260,9 +278,9 @@ MODULE dfpt_module
!
! - P_c | dvpsi >
!
CALL apply_alpha_pc_to_m_wfcs( nbndval, nbndval, dvpsi, (-1._DP,0._DP) )
CALL apply_alpha_pc_to_m_wfcs( nbndval, occband%nloc, dvpsi, (-1._DP,0._DP) )
!
CALL precondition_m_wfcts( nbndval, dvpsi, dpsi, eprec )
CALL precondition_m_wfcts( occband%nloc, dvpsi, dpsi, eprec_loc )
!
IF( l_dost) THEN
!
......@@ -270,7 +288,7 @@ MODULE dfpt_module
! The Hamiltonian is evaluated at the k-point current_k in h_psi
! (see also PHonon/PH/cch_psi_all.f90, where H_(k+q) is evaluated)
!
CALL linsolve_sternheimer_m_wfcts (nbndval, nbndval, dvpsi, dpsi, et(1,ikqs), eprec, tr2, ierr )
CALL linsolve_sternheimer_m_wfcts (nbndval, occband%nloc, dvpsi, dpsi, et_loc, eprec_loc, tr2, ierr )
!
IF(ierr/=0) THEN
WRITE(stdout, '(7X,"** WARNING : PERT ",i8," iks ",I8," not converged, ierr = ",i8)') ipert,iks,ierr
......@@ -285,9 +303,11 @@ MODULE dfpt_module
IF(gamma_only) THEN
!
! double band @ gamma
DO ibnd=1,nbndval
DO lbnd = 1,occband%nloc
!
ibnd = occband%l2g(lbnd)
!
CALL double_invfft_gamma(dffts,npw,npwx,evc(1,ibnd),dpsi(1,ibnd),psic,'Wave')
CALL double_invfft_gamma(dffts,npw,npwx,evc(1,ibnd),dpsi(1,lbnd),psic,'Wave')
DO CONCURRENT (ir=1:dffts%nnr)
aux_r(ir) = aux_r(ir) + CMPLX( REAL( psic(ir),KIND=DP) * DIMAG( psic(ir)) , 0.0_DP, KIND=DP)
ENDDO
......@@ -298,7 +318,9 @@ MODULE dfpt_module
!
ALLOCATE( dpsic(dffts%nnr) )
!
DO ibnd = 1, nbndval
DO lbnd = 1,occband%nloc
!
ibnd = occband%l2g(lbnd)
!
! inverse Fourier transform of wavefunction at [k-q]: (k-q+)G ---> R
!
......@@ -306,7 +328,7 @@ MODULE dfpt_module
!
! inverse Fourier transform of perturbed wavefunction: (k+)G ---> R
!
CALL single_invfft_k(dffts,npw,npwx,dpsi(1,ibnd),dpsic,'Wave',igk_k(1,iks))
CALL single_invfft_k(dffts,npw,npwx,dpsi(1,lbnd),dpsic,'Wave',igk_k(1,iks))
!
DO CONCURRENT (ir = 1: dffts%nnr)
aux_r(ir) = aux_r(ir) + CONJG( psic(ir) * phase(ir) ) * dpsic(ir)
......@@ -315,11 +337,13 @@ MODULE dfpt_module
ENDDO
!
IF (noncolin) THEN
DO ibnd = 1, nbndval
DO lbnd = 1,occband%nloc
!
ibnd = occband%l2g(lbnd)
!
CALL single_invfft_k(dffts,npwkq,npwx,evckmq(npwx+1,ibnd),psic,'Wave',igk_k(1,ikqs))
!
CALL single_invfft_k(dffts,npw,npwx,dpsi(npwx+1,ibnd),dpsic,'Wave',igk_k(1,iks))
CALL single_invfft_k(dffts,npw,npwx,dpsi(npwx+1,lbnd),dpsic,'Wave',igk_k(1,iks))
!
DO CONCURRENT (ir = 1: dffts%nnr)
aux_r(ir) = aux_r(ir) + CONJG( psic(ir) * phase(ir) ) * dpsic(ir)
......@@ -332,6 +356,10 @@ MODULE dfpt_module
!
ENDIF
!
! Sum up aux_r from band groups
!
CALL mp_sum(aux_r,inter_bgrp_comm)
!
! The perturbation is in aux_r
!
ALLOCATE( aux_g(npwqx) )
......@@ -356,6 +384,8 @@ MODULE dfpt_module
IF( m == 0 ) CALL update_bar_type( barra, 'dfpt', 1 )
!
DEALLOCATE( eprec )
DEALLOCATE( eprec_loc )
DEALLOCATE( et_loc )
DEALLOCATE( dpsi )
DEALLOCATE( dvpsi )
!
......
......@@ -22,6 +22,7 @@ SUBROUTINE wstat_memory_report()
USE gvecs, ONLY : ngms
USE uspp, ONLY : nkb
USE control_flags, ONLY : gamma_only
USE mp_bands, ONLY : nbgrp
USE mp_world, ONLY : mpime,root
USE westcom, ONLY : nbnd_occ,n_pdep_basis,npwqx,logfile
USE distribution_center, ONLY : pert
......@@ -141,15 +142,15 @@ SUBROUTINE wstat_memory_report()
WRITE(stdout,'(5x,"[MEM] Allocated arrays ",5x,"est. size (Mb)", 5x,"dimensions")')
WRITE(stdout,'(5x,"[MEM] ----------------------------------------------------------")')
!
mem_partial = (1.0_DP/Mb)*complex_size*npwx*npol*nbnd_occ(1)
mem_partial = (1.0_DP/Mb)*complex_size*npwx*npol*((nbnd_occ(1)-1)/nbgrp+1)
WRITE( stdout, '(5x,"[MEM] dvpsi ",f10.2," Mb", 5x,"(",i7,",",i5,")")') &
mem_partial, npwx*npol, nbnd_occ(1)
mem_partial, npwx*npol, ((nbnd_occ(1)-1)/nbgrp+1)
IF( mpime == root ) CALL json%add( 'memory.dvpsi', mem_partial )
mem_tot = mem_tot + mem_partial
!
mem_partial = (1.0_DP/Mb)*complex_size*npwx*npol*nbnd_occ(1)
mem_partial = (1.0_DP/Mb)*complex_size*npwx*npol*((nbnd_occ(1)-1)/nbgrp+1)
WRITE( stdout, '(5x,"[MEM] dpsi ",f10.2," Mb", 5x,"(",i7,",",i5,")")') &
mem_partial, npwx*npol, nbnd_occ(1)
mem_partial, npwx*npol, ((nbnd_occ(1)-1)/nbgrp+1)
IF( mpime == root ) CALL json%add( 'memory.dpsi', mem_partial )
mem_tot = mem_tot + mem_partial
!
......
......@@ -4,6 +4,7 @@
export NP=2 # Number of MPI processes
export NI=1 # Number of images
export NB=1 # Number of band groups
export NT=1 # Number of OpenMP threads
#
......@@ -25,5 +26,6 @@ export WGET=wget -N -q
###### DO NOT TOUCH BELOW ######
export NIMAGE=${NI}
export NBAND=${NB}
export OMP_NUM_THREADS=${NT}
......@@ -15,7 +15,7 @@ pw:
${PARA_PREFIX_QE} ${BINDIR}/pw.x -i pw.in > pw.out 2> pw.err
wstat: pw
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -i wstat.in > wstat.out 2> wstat.err
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -nband ${NBAND} -i wstat.in > wstat.out 2> wstat.err
wfreq: wstat
${PARA_PREFIX} ${BINDIR}/wfreq.x -nimage ${NIMAGE} -i wfreq.in > wfreq.out 2> wfreq.err
......
......@@ -15,7 +15,7 @@ pw:
${PARA_PREFIX_QE} ${BINDIR}/pw.x -i pw.in > pw.out 2> pw.err
wstat: pw
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -i wstat.in > wstat.out 2> wstat.err
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -nband ${NBAND} -i wstat.in > wstat.out 2> wstat.err
wfreq: wstat
${PARA_PREFIX} ${BINDIR}/wfreq.x -nimage ${NIMAGE} -i wfreq.in > wfreq.out 2> wfreq.err
......
......@@ -15,7 +15,7 @@ pw:
${PARA_PREFIX_QE} ${BINDIR}/pw.x -i pw.in > pw.out 2> pw.err
wstat: pw
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -i wstat.in > wstat.out 2> wstat.err
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -nband ${NBAND} -i wstat.in > wstat.out 2> wstat.err
wfreq: wstat
${PARA_PREFIX} ${BINDIR}/wfreq.x -nimage ${NIMAGE} -i wfreq.in > wfreq.out 2> wfreq.err
......
......@@ -15,7 +15,7 @@ pw:
${PARA_PREFIX_QE} ${BINDIR}/pw.x -i pw.in > pw.out 2> pw.err
wstat: pw
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -i wstat.in > wstat.out 2> wstat.err
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -nband ${NBAND} -i wstat.in > wstat.out 2> wstat.err
wfreq: wstat
${PARA_PREFIX} ${BINDIR}/wfreq.x -nimage ${NIMAGE} -i wfreq.in > wfreq.out 2> wfreq.err
......
......@@ -15,7 +15,7 @@ pw:
${PARA_PREFIX_QE} ${BINDIR}/pw.x -i pw.in > pw.out 2> pw.err
wstat: pw
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -i wstat.in > wstat.out 2> wstat.err
${PARA_PREFIX} ${BINDIR}/wstat.x -nimage ${NIMAGE} -nband ${NBAND} -i wstat.in > wstat.out 2> wstat.err
wfreq: wstat
${PARA_PREFIX} ${BINDIR}/wfreq.x -nimage ${NIMAGE} -i wfreq.in > wfreq.out 2> wfreq.err
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment