#!/bin/bash
# Adapted A. Barral 07/2024 from https://dci.dci-gitlab.cines.fr/webextranet/porting_optimization/detailed_binding_script.html#porting-optimization-detailed-binding-script
# Use on exclusive nodes as: srun --cpu-bind=none --mem-bind=none -- ./set_binding.sh <executable> <args>
# Note: On Adastra, OMP_PLACES & OMP_PROC_BIND seem to have no effect (the default assignment is probably the same), but they are kept here just in case, for other configurations.
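# Example job step (hypothetical executable and values, for illustration):
#   OMP_NUM_THREADS=16 srun --ntasks-per-node=8 --cpu-bind=none --mem-bind=none -- ./set_binding.sh ./my_app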

set -eu

function create_affinity_numactl() {
    local num_tasks=$1  # Number of MPI tasks
    local num_threads=$2  # Number of OMP threads per task

    AFFINITY_NUMACTL=()
    OMP_PLACES_ARR=()
    local start_cpu=0
    local total_strands thread_per_core
    total_strands=392  # "$(lscpu | grep "^CPU(s):" | awk '{print $2}')"  # Total hardware threads (strands) in the system, hard-coded for the target nodes
    thread_per_core="$(lscpu | grep "^Thread(s) per core:" | awk '{print $4}')"
    local total_cpus=$(( total_strands / thread_per_core ))  # Total physical cores in the system
    num_smt_used=$(( (num_tasks*num_threads + total_cpus - 1) / total_cpus ))  # ceil(total_threads_required/total_cpus); deliberately not 'local', it is reported after the call

    # Parameter range check
    if (( num_tasks * num_threads > total_strands )); then
        echo "STOP: requesting more CPUs than available on the system!"; exit 1
    fi
    if (( num_threads % num_smt_used != 0 )); then
        echo "STOP: the number of OMP threads ($num_threads) must be a multiple of $num_smt_used (using $num_smt_used out of $thread_per_core SMT threads per core)."; exit 1
    fi

    for (( task=0; task<num_tasks; task++ )); do
        local range=""
        local range_omp=""
        for (( i_smt=0; i_smt<num_smt_used; i_smt++ )); do
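            # This offset relies on the Linux CPU numbering where the SMT siblings
            # of physical core c are c, c+total_cpus, c+2*total_cpus, ...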
            local smt_start_cpu=$((start_cpu + total_cpus * i_smt))
            local smt_end_cpu=$((smt_start_cpu + num_threads / num_smt_used - 1))
            range+=",${smt_start_cpu}-${smt_end_cpu}"
            range_omp+=",{${smt_start_cpu}}:$((num_threads / num_smt_used)):1"
        done
        range="${range#,}"  # Strip the leading comma
        range_omp="${range_omp#,}"
        AFFINITY_NUMACTL+=("$range")
        OMP_PLACES_ARR+=("$range_omp")
        start_cpu=$((start_cpu + num_threads / num_smt_used))
    done
}
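
# Worked example (illustrative values; with thread_per_core=2 there are 196 physical
# cores): num_tasks=2 and num_threads=196 yield num_smt_used=2 and
#   AFFINITY_NUMACTL = ( "0-97,196-293"  "98-195,294-391" )
#   OMP_PLACES_ARR   = ( "{0}:98:1,{196}:98:1"  "{98}:98:1,{294}:98:1" )
# i.e. each rank gets 98 physical cores plus their SMT siblings.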

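# Fail early with clear messages if the required environment is missing, e.g. when
# the script is not launched through srun (these are standard Slurm/OpenMP variables).
: "${SLURM_NTASKS_PER_NODE:?must be set by srun --ntasks-per-node}"
: "${OMP_NUM_THREADS:?must be exported in the job script}"
: "${SLURM_LOCALID:?must be set by srun}"
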
create_affinity_numactl "$SLURM_NTASKS_PER_NODE" "$OMP_NUM_THREADS"

# Modulo arithmetic eases some corner use cases: the rank index wraps around, so the array lookups below can never go out of bounds.
LOCAL_RANK_INDEX="${SLURM_LOCALID}"
CPU_SET="${AFFINITY_NUMACTL[$((LOCAL_RANK_INDEX % ${#AFFINITY_NUMACTL[@]}))]}"
OMP_PLACES="${OMP_PLACES_ARR[$((LOCAL_RANK_INDEX % ${#OMP_PLACES_ARR[@]}))]}"

if [[ $LOCAL_RANK_INDEX -eq 0 ]]; then
  echo "[$(hostname)] Number of used SMT: $num_smt_used"
  echo "[$(hostname)] AFFINITY_NUMACTL:" "${AFFINITY_NUMACTL[@]}"
  echo "[$(hostname)] OMP_PLACES:" "${OMP_PLACES_ARR[@]}"
fi
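# Bind each OMP thread to its own single-CPU place, packing threads close together.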
export OMP_PLACES
export OMP_PROC_BIND=close
echo "[$(hostname)] Starting local rank ${LOCAL_RANK_INDEX} with: 'numactl --localalloc --physcpubind=${CPU_SET} --' | OMP_PLACES=$OMP_PLACES"
exec numactl --localalloc --physcpubind="${CPU_SET}" -- "${@}"
