#!/bin/bash

### To manage cpu binding on Adastra, add this binding script to the run directory
### In your job script, adapt srun including the following arguments:
### srun --ntasks-per-node=${SLURM_NTASKS_PER_NODE} --cpu-bind=none --mem-bind=none --label -- ./adastra_cpu_binding.sh ./gcm_64x48x54_phymars_para.e > gcm.out 2>&1

set -eu

LOCAL_RANK_INDEX="${SLURM_LOCALID}"
LOCAL_RANK_COUNT="${SLURM_NTASKS_PER_NODE}"
LOCAL_NODE_COUNT="${SLURM_NNODES}"
LOCAL_OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK}"

echo "Adjust binding to ${LOCAL_NODE_COUNT} * ${LOCAL_RANK_COUNT} MPI tasks and ${LOCAL_OMP_NUM_THREADS} OpenMP threads"

function Adastra_MI250_8TasksWith8ThreadsAnd1GPU() {
    AFFINITY_NUMACTL=('48-55' '56-63' '16-23' '24-31' '0-7' '8-15' '32-39' '40-47')
    AFFINITY_GPU=('0' '1' '2' '3' '4' '5' '6' '7')
    export MPICH_OFI_NIC_POLICY=NUMA
}

function Adastra_GENOA_24TasksWith1Threads() {
    AFFINITY_NUMACTL=('0' '1' '2' '3' '4' '5' '6' '7' '8' '9' '10' '11' '12' '13' '14' '15' '16' '17' '18' '19' '20' '21' '22' '23')
}

function Adastra_GENOA_24TasksWith2Threads() {
    AFFINITY_NUMACTL=('0-1' '2-3' '4-5' '6-7' '8-9' '10-11' '12-13' '14-15' '16-17' '18-19' '20-21' '22-23' '24-25' '26-27' '28-29' '30-31' '32-33' '34-35' '36-37' '38-39' '40-41' '42-43' '44-45' '46-47')
}

function Adastra_GENOA_24TasksWith4Threads() {
    AFFINITY_NUMACTL=('0-3' '4-7' '8-11' '12-15' '16-19' '20-23' '24-27' '28-31' '32-35' '36-39' '40-43' '44-47' '48-51' '52-55' '56-59' '60-63' '64-67' '68-71' '72-75' '76-79' '80-83' '84-87' '88-91' '92-95')
}

function Adastra_GENOA_24TasksWith8Threads() {
    AFFINITY_NUMACTL=('0-7' '8-15' '16-23' '24-31' '32-39' '40-47' '48-55' '56-63' '64-71' '72-79' '80-87' '88-95' '96-103' '104-111' '112-119' '120-127' '128-135' '136-143' '144-151' '152-159' '160-167' '168-175' '176-183' '184-191')
}

function Adastra_GENOA_48TasksWith4Threads() {
    AFFINITY_NUMACTL=('0-3' '4-7' '8-11' '12-15' '16-19' '20-23' '24-27' '28-31' '32-35' '36-39' '40-43' '44-47' '48-51' '52-55' '56-59' '60-63' '64-67' '68-71' '72-75' '76-79' '80-83' '84-87' '88-91' '92-95' '96-99' '100-103' '104-107' '108-111' '112-115' '116-119' '120-123' '124-127' '128-131' '132-135' '136-139' '140-143' '144-147' '148-151' '152-155' '156-159' '160-163' '164-167' '168-171' '172-175' '176-179' '180-183' '184-187' '188-191')
}

function Adastra_GENOA_24TasksWith16Threads() {
    # Requires SMT to be enabled.
    AFFINITY_NUMACTL=('0-15' '16-31' '32-47' '48-63' '64-79' '80-95' '96-111' '112-127' '128-143' '144-159' '160-175' '176-191' '192-207' '208-223' '224-239' '240-255' '256-271' '272-287' '288-303' '304-319' '320-335' '336-351' '352-367' '368-383')
}

function Adastra_GENOA_12TasksWith16Threads() {
    AFFINITY_NUMACTL=('0-15' '16-31' '32-47' '48-63' '64-79' '80-95' '96-111' '112-127' '128-143' '144-159' '160-175' '176-191')
}

if [ "${LOCAL_RANK_COUNT}" -eq 12 ]; then
	if [ "${LOCAL_OMP_NUM_THREADS}" -eq 16 ]; then
                Adastra_GENOA_12TasksWith16Threads
	else
                echo "Could not find a function for the given MPI/OMP combination"
                exit 1
        fi
elif [ "${LOCAL_RANK_COUNT}" -eq 24 ]; then
	if [ "${LOCAL_OMP_NUM_THREADS}" -eq 1 ]; then
		Adastra_GENOA_24TasksWith1Threads
	elif [ "${LOCAL_OMP_NUM_THREADS}" -eq 2 ]; then
		Adastra_GENOA_24TasksWith2Threads
	elif [ "${LOCAL_OMP_NUM_THREADS}" -eq 4 ]; then
		Adastra_GENOA_24TasksWith4Threads
	elif [ "${LOCAL_OMP_NUM_THREADS}" -eq 8 ]; then
                Adastra_GENOA_24TasksWith8Threads
	elif [ "${LOCAL_OMP_NUM_THREADS}" -eq 16 ]; then
		Adastra_GENOA_24TasksWith16Threads
	else
		echo "Could not find a function for the given MPI/OMP combination"
		exit 1
	fi 
elif [ "${LOCAL_RANK_COUNT}" -eq 48 ]; then
	if [ "${LOCAL_OMP_NUM_THREADS}" -eq 4 ]; then
		Adastra_GENOA_48TasksWith4Threads
	else
		echo "Could not find a function for the given MPI/OMP combination"
                exit 1
	fi
else
	echo "Could not find a binding function for the given MPI/OMP combination"
	exit 1
fi

CPU_SET="${AFFINITY_NUMACTL[$((${LOCAL_RANK_INDEX} % ${#AFFINITY_NUMACTL[@]}))]}"
if [ ! -z ${AFFINITY_GPU+x} ]; then
    GPU_SET="${AFFINITY_GPU[$((${LOCAL_RANK_INDEX} % ${#AFFINITY_GPU[@]}))]}"
    export ROCR_VISIBLE_DEVICES="${GPU_SET}"
fi
exec numactl --localalloc --physcpubind="${CPU_SET}" -- "${@}"
