Commit abb81083 authored by Nathanael Schaeffer @home
Browse files

rotation functions parallelized + minor fixes

parent f9a8da4e
......@@ -37,6 +37,9 @@
// Global starting point for existing SHTns configurations; initially NULL.
// NOTE(review): presumably the head of a list of configs -- shtns_create() starts from it
// when checking whether some data can be shared; confirm against the struct definition.
shtns_cfg sht_data = NULL;
#ifdef _OPENMP
// Thread count requested through shtns_use_threads(); shtns_create() copies it
// (possibly reduced) into shtns->nthreads.
int omp_threads = 1; // multi-thread disabled by default.
#ifndef USE_LEGACY_FFTW3
// OMP_FFTW is set for OpenMP builds that are not the legacy FFTW3 build; it guards
// the calls to fftw_init_threads() and fftw_plan_with_nthreads() in this file.
#define OMP_FFTW
#endif
#else
// Without OpenMP, omp_threads is the constant 1 so code can refer to it unconditionally.
#define omp_threads 1
#endif
......@@ -140,8 +143,12 @@ static void SH_rotK90(shtns_cfg shtns, complex double *Qlm, complex double *Rlm,
ntheta = ((lmax+2)>>1)*2;
q0 = malloc(2* sizeof(double)*(2*ntheta+2)*lmax);
memset(q0, 0, 2* sizeof(double)*(2*ntheta+2)*lmax); // zero out.
yl = malloc(2* sizeof(double)*(lmax+1));
dyl = yl + (lmax+1);
dyl = malloc(sizeof(double)*(lmax+1) * (1+shtns->nthreads));
yl = dyl + (lmax+1);
#ifdef OMP_FFTW
k = (lmax < 63) ? 1 : shtns->nthreads;
fftw_plan_with_nthreads(k);
#endif
// rotate around Z by dphi0
if (dphi0 != 0.0) {
......@@ -151,7 +158,13 @@ static void SH_rotK90(shtns_cfg shtns, complex double *Qlm, complex double *Rlm,
Rlm[0] = Qlm[0]; // l=0 is rotation invariant.
}
#pragma omp parallel private(k,m,l) firstprivate(yl) num_threads(shtns->nthreads)
{
#ifdef _OPENMP
yl += (lmax+1)*omp_get_thread_num();
#endif
// compute q(l) on the meridian phi=0 and phi=pi. (rotate around X)
#pragma omp for schedule(static)
for (k=0; k<ntheta/2; ++k) {
double cost= cos(M_PI*(k+0.5)/ntheta);
double sint_1 = 1.0/sqrt((1.0-cost)*(1.0+cost));
......@@ -187,6 +200,7 @@ static void SH_rotK90(shtns_cfg shtns, complex double *Qlm, complex double *Rlm,
}
}
}
}
q = (complex double*) q0;
ntheta*=2; nrembed = ntheta+2; ncembed = nrembed/2;
......@@ -215,7 +229,7 @@ static void SH_rotK90(shtns_cfg shtns, complex double *Qlm, complex double *Rlm,
Rlm[LiM(shtns, l,m)] = eimdp*q[m*2*lmax +2*(l-1)]/yl[l];
}
}
free(yl); free(q0);
free(dyl); free(q0);
}
......@@ -395,6 +409,9 @@ void SHqst_to_lat(shtns_cfg shtns, complex double *Qlm, complex double *Slm, com
if ((nphi != nphi_lat)||(ifft_lat == NULL)) {
if (ifft_lat != NULL) fftw_destroy_plan(ifft_lat);
#ifdef OMP_FFTW
fftw_plan_with_nthreads(1);
#endif
ifft_lat = fftw_plan_dft_c2r_1d(nphi, vrc, vr, FFTW_ESTIMATE);
nphi_lat = nphi;
}
......@@ -778,13 +795,11 @@ static void planFFT(shtns_cfg shtns, int layout, int on_the_fly)
if (NPHI <= 2*MMAX) shtns_runerr("the sampling condition Nphi > 2*Mmax is not met.");
#if _OPENMP
#ifndef USE_LEGACY_FFTW3
fftw_init_threads();
fftw_plan_with_nthreads(omp_threads);
if ((shtns->fftw_plan_mode == FFTW_EXHAUSTIVE) && (omp_threads > 1))
#ifdef OMP_FFTW
if ((shtns->fftw_plan_mode & (FFTW_EXHAUSTIVE | FFTW_PATIENT)) && (omp_threads > 1)) {
shtns->fftw_plan_mode = FFTW_PATIENT;
#endif
fftw_plan_with_nthreads(omp_threads);
} else fftw_plan_with_nthreads(shtns->nthreads);
#endif
shtns->fft = NULL; shtns->ifft = NULL;
......@@ -919,6 +934,11 @@ static void planDCT(shtns_cfg shtns)
fftw_iodim dims, hdims[2];
double Sh0[NLAT] SSE; // temp storage on the stack, aligned.
#ifdef OMP_FFTW
if ((shtns->fftw_plan_mode & (FFTW_EXHAUSTIVE | FFTW_PATIENT)) && (omp_threads > 1)) {
fftw_plan_with_nthreads(omp_threads);
} else fftw_plan_with_nthreads(1);
#endif
// Allocate dummy Spatial Fields.
Sh = (double *) VMALLOC((NPHI/2 +1) * NLAT*2 * sizeof(double));
......@@ -1279,6 +1299,9 @@ static void init_SH_dct(shtns_cfg shtns, int analysis)
if (vector) shtns->dykm_dct[im+1] = shtns->dykm_dct[im] + dsk;
}
#ifdef OMP_FFTW
fftw_plan_with_nthreads(1);
#endif
dct = fftw_plan_r2r_1d( 2*NLAT_2, Z, Z, FFTW_REDFT10, FFTW_ESTIMATE ); // quick and dirty dct.
idct = fftw_plan_r2r_1d( 2*NLAT_2, Z, Z, FFTW_REDFT01, FFTW_ESTIMATE ); // quick and dirty idct.
......@@ -1943,7 +1966,7 @@ void shtns_print_version() {
void shtns_print_cfg(shtns_cfg shtns)
{
printf("Lmax=%d, Mmax*Mres=%d, Mres=%d, Nlm=%d [",LMAX, MMAX*MRES, MRES, NLM);
printf("Lmax=%d, Mmax*Mres=%d, Mres=%d, Nlm=%d [%d threads, ",LMAX, MMAX*MRES, MRES, NLM, shtns->nthreads);
if (shtns->norm & SHT_REAL_NORM) printf("'real' norm, ");
if (shtns->norm & SHT_NO_CS_PHASE) printf("no Condon-Shortley phase, ");
if (SHT_NORM == sht_fourpi) printf("4.pi normalized]\n");
......@@ -2034,11 +2057,14 @@ shtns_cfg shtns_create(int lmax, int mmax, int mres, enum shtns_norm norm)
mpos_renorm = 0.5; // normalization for 'real' spherical harmonics.
shtns->mmax = mmax; shtns->mres = mres; shtns->lmax = lmax;
shtns->nlm = nlm_calc(LMAX, MMAX, MRES);
#if SHT_VERBOSE > 0
shtns_print_version();
printf(" "); shtns_print_cfg(shtns);
#endif
shtns->nlm = nlm_calc(lmax, mmax, mres);
shtns->nthreads = omp_threads;
if (omp_threads > mmax+1) shtns->nthreads = mmax+1; // limit the number of threads to mmax+1
#if SHT_VERBOSE > 0
shtns_print_version();
printf(" "); shtns_print_cfg(shtns);
if (shtns->nthreads > 1) printf(" => enabled %d OpenMP threads\n",shtns->nthreads);
#endif
s2 = sht_data; // check if some data can be shared ...
while(s2 != NULL) {
......@@ -2190,7 +2216,7 @@ void shtns_reset()
*/
int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int nl_order, int *nlat, int *nphi)
{
double t, latdir;
double t, mem;
int im,m;
int layout;
int nloop = 0;
......@@ -2198,6 +2224,7 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
int on_the_fly = 0;
int quick_init = 0;
int vector = !(flags & SHT_SCALAR_ONLY);
int latdir = (flags & SHT_SOUTH_POLE_FIRST) ? -1 : 1; // choose latitudinal direction (change sign of ct)
int analys = 1;
#if _GCC_VEC_
......@@ -2211,16 +2238,9 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
shtns->nspat = 0;
shtns->nlorder = nl_order;
shtns->mtr_dct = -1; // dct switched off completely.
latdir = (flags & SHT_SOUTH_POLE_FIRST) ? -1 : 1; // choose latitudinal direction (change sign of ct)
layout = flags & 0xFFFF00;
flags = flags & 255; // clear higher bits.
shtns->nthreads = omp_threads;
if (omp_threads > shtns->mmax+1) shtns->nthreads = shtns->mmax+1; // limit the number of threads to mmax+1
#if SHT_VERBOSE > 0
if (shtns->nthreads > 1) printf(" => enabled %d OpenMP threads\n",shtns->nthreads);
#endif
switch (flags) {
case sht_gauss_fly : flags = sht_gauss; on_the_fly = 1; break;
case sht_quick_init : flags = sht_gauss; quick_init = 1; break;
......@@ -2240,8 +2260,8 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
} else *nlat = n_gauss;
}
t = sht_mem_size(shtns->lmax, shtns->mmax, shtns->mres, *nlat);
if (analys) t*=2; if (vector) t*=3;
mem = sht_mem_size(shtns->lmax, shtns->mmax, shtns->mres, *nlat);
t=mem; if (analys) t*=2; if (vector) t*=3;
#if SHT_VERBOSE > 1
printf("Memory required for precomputed matrices (estimate) : %.3f Mb\n",t);
#endif
......@@ -2262,6 +2282,7 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
if (*nphi > 1024) shtns->fftw_plan_mode = FFTW_MEASURE;
} else {
shtns->fftw_plan_mode = FFTW_ESTIMATE;
if ((mem < 1.0) && (SHT_VERBOSE < 2)) shtns->nthreads = 1; // disable threads for small transforms (in quickinit mode).
if ((VSIZE2 >= 4) && (*nlat >= VSIZE2*4)) on_the_fly = 1; // with AVX, on-the-fly should be the default (faster).
if ((shtns->nthreads > 1) && (*nlat >= VSIZE2*16)) on_the_fly = 1; // force multi-thread transforms
}
......@@ -2303,12 +2324,12 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
} else {
t = SHT_error(shtns, vector);
if (t > MIN_ACCURACY_DCT) {
#if SHT_VERBOSE > 0
#if SHT_VERBOSE > 0
printf(" !! Not enough accuracy (%.3g) => DCT disabled.\n",t);
#endif
#if SHT_VERBOSE < 2
#endif
#if SHT_VERBOSE < 2
Set_MTR_DCT(shtns, -1); // turn off DCT.
#endif
#endif
}
}
}
......@@ -2373,6 +2394,9 @@ int shtns_set_grid_auto(shtns_cfg shtns, enum shtns_type flags, double eps, int
#endif
}
// set_sht_fly(shtns, SHT_TYP_VAN);
#if SHT_VERBOSE > 1
if (omp_threads > 1) printf(" nthreads = %d\n",shtns->nthreads);
#endif
#if SHT_VERBOSE > 0
printf(" => SHTns is ready.\n");
#endif
......@@ -2421,8 +2445,13 @@ shtns_cfg shtns_init(enum shtns_type flags, int lmax, int mmax, int mres, int nl
/// Selects the number of OpenMP threads requested for SHTns transforms.
/// The value is stored in the global omp_threads, which shtns_create() later
/// copies into each new configuration.
/// \param num_threads requested thread count; a value <= 0 selects the number of
///        processors reported by omp_get_num_procs(). Requests larger than 4 times
///        that number are clamped to 4 times the processor count.
/// \return the thread count now in effect (always 1 when compiled without OpenMP,
///         where omp_threads is a constant).
int shtns_use_threads(int num_threads)
{
#ifdef _OPENMP
	// Query the processor count once; it serves both as the default and as the clamp base.
	const int procs = omp_get_num_procs();
	if (num_threads <= 0) {
		num_threads = procs;			// automatic choice.
	} else if (num_threads > 4*procs) {
		num_threads = 4*procs;			// limit the number of threads
	}
	omp_threads = num_threads;
#endif
#ifdef OMP_FFTW
	// NOTE(review): the FFTW manual states fftw_init_threads() returns zero on error;
	// the result is not checked here (unchanged from the original behaviour).
	fftw_init_threads();		// enable threads for FFTW.
#endif
	return omp_threads;
}
......@@ -2439,6 +2468,12 @@ int shtns_use_threads(int num_threads)
* see the \link SHT_example.f Fortran example \endlink for a simple usage of SHTns from Fortran language.
*/
//@{
/// Fortran binding that enables multi-threaded transforms; forwards to shtns_use_threads().
/// \param num_threads pointer to the requested thread count (the argument arrives by reference).
void shtns_use_threads_(int *num_threads)
{
	const int requested = *num_threads;
	shtns_use_threads(requested);	// the resulting thread count is not passed back to the caller.
}
/// Initializes spherical harmonic transforms of given size using Gauss algorithm with default polar optimization.
void shtns_init_sh_gauss_(int *layout, int *lmax, int *mmax, int *mres, int *nlat, int *nphi)
......
......@@ -36,6 +36,13 @@
#define SUPARG2 , shtns->lmax
#endif
// ADD_OPENMP is defined only for OpenMP builds in which SHT_AXISYM is not defined.
// It guards the inclusion of the *_omp.c transform sources and the omp rows of the
// function table in this file, and is #undef'd again at the end of the file.
#ifdef _OPENMP
#ifndef SHT_AXISYM
#define ADD_OPENMP
#endif
#endif
/// \name scalar transforms
//@{
......@@ -77,7 +84,7 @@
#include "spat_to_SH_fly.c"
#include "SH_to_spat_fly.c"
#undef NWAY
#ifdef _OPENMP
#ifdef ADD_OPENMP
#define NWAY 1
#include "spat_to_SHst_omp.c"
#include "SHst_to_spat_omp.c"
......@@ -167,7 +174,7 @@ void GEN(spat_to_SHsphtor,SUFFIX)(shtns_cfg shtns, double *Vt, double *Vp, compl
#include "SHs_to_spat_fly.c"
#include "SHt_to_spat_fly.c"
#undef NWAY
#ifdef _OPENMP
#ifdef ADD_OPENMP
#define NWAY 1
#include "SHs_to_spat_omp.c"
#include "SHt_to_spat_omp.c"
......@@ -260,7 +267,7 @@ void GEN(SHtor_to_spat,SUFFIX)(shtns_cfg shtns, complex double *Tlm, double *Vp
#include "spat_to_SHqst_fly.c"
#include "SHqst_to_spat_fly.c"
#undef NWAY
#ifdef _OPENMP
#ifdef ADD_OPENMP
#define NWAY 1
#include "spat_to_SHqst_omp.c"
#include "SHqst_to_spat_omp.c"
......@@ -351,7 +358,7 @@ void* GEN(sht_array, SUFFIX)[SHT_NALG][SHT_NTYP] = {
NULL, NULL, NULL, NULL },
/* fly8 */ { GEN(SH_to_spat_fly8, SUFFIX), GEN(spat_to_SH_fly8, SUFFIX), NULL, NULL,
NULL, NULL, NULL, NULL },
#ifdef _OPENMP
#ifdef ADD_OPENMP
/* omp1 */ { NULL, NULL, GEN(SHsphtor_to_spat_omp1, SUFFIX), GEN(spat_to_SHsphtor_omp1, SUFFIX),
GEN(SHsph_to_spat_omp1, SUFFIX), GEN(SHtor_to_spat_omp1, SUFFIX), GEN(SHqst_to_spat_omp1, SUFFIX), GEN(spat_to_SHqst_omp1, SUFFIX) },
/* omp2 */ { GEN(SH_to_spat_omp2, SUFFIX), GEN(spat_to_SH_omp2, SUFFIX), GEN(SHsphtor_to_spat_omp2, SUFFIX), GEN(spat_to_SHsphtor_omp2, SUFFIX),
......@@ -471,3 +478,4 @@ void GENF(spat_to_qst,SUFFIX)(double *Vr, double *Vt, double *Vp, complex double
#undef SHT_AXISYM
#undef SHT_VAR_LTR
#undef IVAR
#undef ADD_OPENMP
......@@ -49,6 +49,7 @@ int main()
lmax = 5; nlat = 32;
mmax = 3; nphi = 10;
mres = 1;
shtns_use_threads(0); // enable multi-threaded transforms (if supported).
shtns = shtns_init( sht_gauss, lmax, mmax, mres, nlat, nphi );
// shtns = shtns_create(lmax, mmax, mres, sht_orthonormal | SHT_REAL_NORM);
// shtns_set_grid(shtns, sht_gauss, 0.0, nlat, nphi);
......
......@@ -44,6 +44,9 @@
call shtns_calc_nlm(nlm, lmax, mmax, mres)
print*,'NLM=',nlm
! enable multi-threaded transform (OpenMP) if supported.
call shtns_use_threads(0)
! init SHT. SHT_PHI_CONTIGUOUS is defined in 'shtns.f'
layout = SHT_PHI_CONTIGUOUS
call shtns_init_sh_gauss(layout, lmax, mmax, mres, nlat, nphi)
......
......@@ -34,7 +34,6 @@
#include "shtns.h"
#ifdef _OPENMP
#warning "using OpenMP"
#include <omp.h>
#endif
......
......@@ -79,6 +79,7 @@ int main(int argc, char *argv[])
NLAT=64;
NPHI=128;
shtns_use_threads(0);
shtns = shtns_create(LMAX, MMAX, MRES, shtnorm);
NLM = shtns->nlm;
shtns_set_grid_auto(shtns, sht_quick_init, 0.0, 1e-10, &NLAT, &NPHI);
......
......@@ -3,45 +3,38 @@
id=`hg id`
log="test_suite.log"
# Run one test command and record its output.
#   $1 : complete command line; it is split on whitespace when executed.
# The command is echoed to the terminal and its stdout captured in tmp.out.
# Lines containing "ERROR" or (case-insensitively) "nan" are shown on the terminal,
# and the full output is appended to the log file named by $log.
test1() {
	echo $1
	{
		echo "---"
		echo "*** $1 *** "
	} >> "$log"
	$1 > tmp.out
	grep ERROR tmp.out
	grep -i nan tmp.out
	cat tmp.out >> "$log"
}
echo "beginning test suite for $id" > $log
# first, do a huge transform :
c="./time_SHT 2047 -mres=15 -quickinit -iter=1"
echo $c
echo "---" >> $log
echo "*** $c *** " >> $log
$c > tmp.out
cat tmp.out | grep ERROR
cat tmp.out >> $log
test1 "./time_SHT 2047 -mres=15 -quickinit -iter=1"
# even bigger :
c="./time_SHT 7975 -mres=145 -quickinit -iter=1"
echo $c
echo "---" >> $log
echo "*** $c *** " >> $log
$c > tmp.out
cat tmp.out | grep ERROR
cat tmp.out >> $log
test1 "./time_SHT 7975 -mres=145 -quickinit -iter=1"
# without threads
test1 "./time_SHT 2047 -mres=15 -quickinit -iter=1 -nth=1"
for switch in "" "-oop" "-transpose" "-schmidt" "-4pi"
do
for mode in "-quickinit" "-gauss" "-reg" "-fly"
for mode in "-quickinit" "-gauss" "-reg" "-fly" "-gauss -nth=1"
do
for lmax in 1 2 3 4 11 12 13 14 31 32 33 34 121 122 123 124
do
for mmax in 0 1 $lmax
do
c="./time_SHT $lmax -mmax=$mmax $mode $switch -iter=1"
echo $c
echo "---" >> $log
echo "*** $c *** " >> $log
$c > tmp.out
cat tmp.out | grep ERROR
cat tmp.out >> $log
test1 "./time_SHT $lmax -mmax=$mmax $mode $switch -iter=1"
done
done
done
done
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment