PROGRAM reshape_XIOS_output

!=======================================================================
! Purpose: Read XIOS NetCDF files and convert them onto the PCM grid.
!          XIOS longitudes run from -180 to +180 (exclusive). So we append
!          the first longitude value again at the end in the output to
!          complete the grid. Done for the two PCM years.
!
! Authors: RV & LL (original), JBC (optimized)
!=======================================================================

use netcdf
use version_info_mod, only: print_version_info

implicit none

! Variables for NetCDF I/O and bookkeeping
integer                            :: state
integer                            :: ncid_in, ncid_out
integer                            :: ndims, nvars, nGlobalAtts, unlimDimID
integer, allocatable, dimension(:) :: dimids_in, varids_in
integer, allocatable, dimension(:) :: dimids_out, varids_out

! Store each input dimension name and length
character(30), allocatable, dimension(:) :: dimNames
integer,       allocatable, dimension(:) :: dimLens

! Which input‐index corresponds to lon/lat/time/soil (–1 if not present)
integer :: idx_lon_in  = -1
integer :: idx_lat_in  = -1
integer :: idx_time_in = -1
integer :: idx_soil_in = -1

! Lengths of key dims (input), plus output lon length
integer :: len_lon_in, len_lat_in, len_time_in, len_soil_in
integer :: len_lon_out

! Loop and variable bookkeeping
integer                            :: i, j, k
integer                            :: numDimsVar, numAttsVar
character(100)                     :: varName, arg
integer                            :: xtypeVar
integer, allocatable, dimension(:) :: dimids_var_in

! Buffers for reading/writing when first‐dim = lon (max‐sized)
real, allocatable, dimension(:)       :: buf1D_in, buf1D_out
real, allocatable, dimension(:,:)     :: buf2D_in, buf2D_out
real, allocatable, dimension(:,:,:)   :: buf3D_in, buf3D_out
real, allocatable, dimension(:,:,:,:) :: buf4D_in, buf4D_out

! Temporaries for "non‐lon‐first" variables
real, allocatable, dimension(:)       :: tmp1D
real, allocatable, dimension(:,:)     :: tmp2D
real, allocatable, dimension(:,:,:)   :: tmp3D
real, allocatable, dimension(:,:,:,:) :: tmp4D

! Temporaries for dimension inquiries
integer       :: thisLen
integer       :: len1, len2, len3, len4
integer       :: lenDim2, lenDim3, lenDim4
character(30) :: tmpDimName

logical :: uses_lon_first

! For looping over two "years"
integer      :: numyear
character(4) :: str

! For deleting existing output
integer :: cstat
logical :: exists

! CODE
! Handle command‐line argument "version"
if (command_argument_count() > 0) then ! Get the number of command-line arguments
    call get_command_argument(1,arg) ! Read the argument given to the program
    select case (trim(adjustl(arg)))
        case('version')
            call print_version_info()
            stop
        case default
            error stop 'The argument given to the program is unknown!'
    end select
endif

! Main loop: two PCM years
do numyear = 1,2
    write(str,'(I1.1)') numyear
    write(*,*) "> Reshaping variables from ""data2reshape_Y"//trim(str)//".nc""..."

    ! Open input file (read‐only)
    state = nf90_open("data2reshape_Y"//trim(str)//".nc",mode = nf90_nowrite,ncid = ncid_in)
    if (state /= nf90_noerr) call handle_err(state)

    ! If output exists, delete it
    inquire(file = "data_PCM_Y"//trim(str)//".nc",exist = exists)
    if (exists) then
        call execute_command_line("rm data_PCM_Y"//trim(str)//".nc",cmdstat = cstat)
        if (cstat > 0) then
            error stop 'Command exection failed!'
        else if (cstat < 0) then
            error stop 'Command execution not supported!'
        endif
    endif

    ! Create output file in define mode
    state = nf90_create("data_PCM_Y"//trim(str)//".nc",cmode = or(nf90_noclobber,nf90_64bit_offset),ncid = ncid_out)
    if (state /= nf90_noerr) call handle_err(state)

    ! Inquire input for dims, vars, global atts, unlimited dim ID
    state = nf90_inquire(ncid_in,ndims,nvars,nGlobalAtts,unlimDimID)
    if (state /= nf90_noerr) call handle_err(state)

    ! Allocate arrays for dim IDs, var IDs, names, lengths
    allocate(dimids_in(ndims),varids_in(nvars),dimids_out(ndims),varids_out(nvars),dimNames(ndims),dimLens(ndims))

    ! Get the dimension IDs and then query each for its name and length
    state = nf90_inq_dimids(ncid_in,ndims,dimids_in,unlimDimID)
    if (state /= nf90_noerr) call handle_err(state)

    do i = 1,ndims
        state = nf90_inquire_dimension(ncid_in,dimids_in(i),dimNames(i),dimLens(i))
        if (state /= nf90_noerr) call handle_err(state)

        select case (trim(dimNames(i)))
            case ("lon","longitude")
                idx_lon_in = i
                len_lon_in = dimLens(i)
            case ("lat","latitude")
                idx_lat_in = i
                len_lat_in = dimLens(i)
            case ("time_counter","Time")
                idx_time_in = i
                len_time_in = dimLens(i)
            case ("soil_layers","subsurface_layers")
                idx_soil_in = i
                len_soil_in = dimLens(i)
            case default
                ! nothing special
        end select

        ! Define the same dimension in the output, except lon becoming (len_lon_in + 1)
        if (i == idx_lon_in) then
            len_lon_out = len_lon_in + 1
            state = nf90_def_dim(ncid_out,trim(dimNames(i)),len_lon_out,dimids_out(i))
        else
            state = nf90_def_dim(ncid_out,trim(dimNames(i)),dimLens(i),dimids_out(i))
        endif
        if (state /= nf90_noerr) call handle_err(state)
    enddo

    ! Ensure mandatory dims exist
    if (idx_lon_in < 0 .or. idx_lat_in < 0) error stop "Input is missing mandatory 'lon' or 'lat' dimension."
    if (idx_time_in < 0) len_time_in = 1
    if (idx_soil_in < 0) len_soil_in = 1

    ! Allocate only the "lon‐first" buffers (max‐sized) once
    allocate(buf1D_in(len_lon_in),buf1D_out(len_lon_out))
    allocate(buf2D_in(len_lon_in,len_lat_in),buf2D_out(len_lon_out, len_lat_in))
    allocate(buf3D_in(len_lon_in,len_lat_in,len_time_in),buf3D_out(len_lon_out,len_lat_in,len_time_in))
    allocate(buf4D_in(len_lon_in,len_lat_in,len_soil_in,len_time_in),buf4D_out(len_lon_out,len_lat_in,len_soil_in,len_time_in))

    ! Get all variable IDs
    state = nf90_inq_varids(ncid_in,nvars,varids_in)
    if (state /= nf90_noerr) call handle_err(state)

    ! Loop over each variable to define it in the output
    do i = 1,nvars
        ! Inquire name, xtype, ndims, natts
        state = nf90_inquire_variable(ncid_in,varids_in(i),name = varName,xtype = xtypeVar,ndims = numDimsVar,natts = numAttsVar)
        if (state /= nf90_noerr) call handle_err(state)
        write(*,*) 'Treatment of '//varName

        allocate(dimids_var_in(numDimsVar))
        state = nf90_inquire_variable(ncid_in,varids_in(i),name = varName,xtype = xtypeVar,ndims = numDimsVar,dimids = dimids_var_in,natts = numAttsVar)
        if (state /= nf90_noerr) call handle_err(state)

        ! Detect if this variable first dimension is "lon"
        if (numDimsVar >= 1 .and. dimids_var_in(1) == dimids_in(idx_lon_in)) then
            uses_lon_first = .true.
        else
            uses_lon_first = .false.
        endif

        ! Build the output‐dimids list: replace the first dim with the output lon if needed
        if (uses_lon_first) dimids_var_in(1) = dimids_out(idx_lon_in)
        do j = 2,numDimsVar
        ! Map each subsequent input dim to its output dim
            do k = 1,ndims
                if (dimids_var_in(j) == dimids_in(k)) then
                    dimids_var_in(j) = dimids_out(k)
                    exit
                endif
            enddo
        enddo

        ! Define this variable (same name, same xtype, but new dimids)
        state = nf90_def_var(ncid_out,trim(varName),xtypeVar,dimids_var_in,varids_out(i))
        if (state /= nf90_noerr) call handle_err(state)

        deallocate(dimids_var_in)
    enddo

    ! Done defining all dims and vars exit define mode exactly once
    state = nf90_enddef(ncid_out)
    if (state /= nf90_noerr) call handle_err(state)

    ! Loop over each variable to read from input and write to output
    do i = 1,nvars
        ! Re‐inquire metadata so we know dimids_var_in and numDimsVar
        state = nf90_inquire_variable(ncid_in,varids_in(i),name = varName,xtype = xtypeVar,ndims = numDimsVar,natts = numAttsVar)
        if (state /= nf90_noerr) call handle_err(state)

        allocate(dimids_var_in(numDimsVar))
        state = nf90_inquire_variable(ncid_in, varids_in(i),name = varName,xtype = xtypeVar,ndims = numDimsVar,dimids = dimids_var_in,natts = numAttsVar)
        if (state /= nf90_noerr) call handle_err(state)

        ! Detect again if first dim = lon
        if (numDimsVar >= 1 .and. dimids_var_in(1) == dimids_in(idx_lon_in)) then
            uses_lon_first = .true.
        else
            uses_lon_first = .false.
        endif

        select case (numDimsVar)
            case (1)
                if (uses_lon_first) then
                    ! 1D lon sequence: read len_lon_in, extend to len_lon_out
                    state = nf90_get_var(ncid_in,varids_in(i),buf1D_in)
                    if (state /= nf90_noerr) call handle_err(state)

                    buf1D_out(1:len_lon_in) = buf1D_in(1:len_lon_in)
                    buf1D_out(len_lon_out)  = buf1D_in(1)  ! repeat first lon at end

                    state = nf90_put_var(ncid_out,varids_out(i),buf1D_out)
                    if (state /= nf90_noerr) call handle_err(state)

                else
                    ! Some other 1D (e.g. lat or time). Allocate exact 1D temp:
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(1),tmpDimName, thisLen)
                    if (state /= nf90_noerr) call handle_err(state)

                    allocate(tmp1D(thisLen))
                    state = nf90_get_var(ncid_in,varids_in(i),tmp1D(1:thisLen))
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_put_var(ncid_out,varids_out(i),tmp1D(1:thisLen))
                    if (state /= nf90_noerr) call handle_err(state)

                    deallocate(tmp1D)
                endif

            case (2)
                if (uses_lon_first) then
                    ! 2D with first dim = lon (len_lon_in × lenDim2)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(2),tmpDimName,lenDim2)
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_get_var(ncid_in,varids_in(i),buf2D_in(1:len_lon_in,1:lenDim2))
                    if (state /= nf90_noerr) call handle_err(state)

                    buf2D_out(1:len_lon_in,1:lenDim2) = buf2D_in(1:len_lon_in,1:lenDim2)
                    buf2D_out(len_lon_out,1:lenDim2) = buf2D_in(1,1:lenDim2)

                    state = nf90_put_var(ncid_out,varids_out(i),buf2D_out(1:len_lon_out,1:lenDim2))
                    if (state /= nf90_noerr) call handle_err(state)

                else
                    ! Some other 2D (no lon‐extension). Allocate exact 2D temp:
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(1),tmpDimName,len1)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in, dimids_var_in(2),tmpDimName,len2)
                    if (state /= nf90_noerr) call handle_err(state)

                    allocate(tmp2D(len1,len2))
                    state = nf90_get_var(ncid_in,varids_in(i),tmp2D(1:len1,1:len2))
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_put_var(ncid_out, varids_out(i), tmp2D(1:len1,1:len2))
                    if (state /= nf90_noerr) call handle_err(state)

                    deallocate(tmp2D)
                endif

            case (3)
                if (uses_lon_first) then
                    ! 3D with first dim = lon (len_lon_in × lenDim2 × lenDim3)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(2),tmpDimName,lenDim2)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(3),tmpDimName,lenDim3)
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_get_var(ncid_in,varids_in(i),buf3D_in(1:len_lon_in,1:lenDim2,1:lenDim3))
                    if (state /= nf90_noerr) call handle_err(state)

                    buf3D_out(1:len_lon_in,1:lenDim2,1:lenDim3) = buf3D_in(1:len_lon_in,1:lenDim2,1:lenDim3)
                    buf3D_out(len_lon_out,1:lenDim2,1:lenDim3) = buf3D_in(1,1:lenDim2,1:lenDim3)

                    state = nf90_put_var(ncid_out,varids_out(i),buf3D_out(1:len_lon_out,1:lenDim2,1:lenDim3))
                    if (state /= nf90_noerr) call handle_err(state)

                else
                    ! Some other 3D (no lon‐extension). Allocate exact 3D temp:
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(1),tmpDimName,len1)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(2),tmpDimName,len2)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(3),tmpDimName,len3)
                    if (state /= nf90_noerr) call handle_err(state)

                    allocate(tmp3D(len1,len2,len3))
                    state = nf90_get_var(ncid_in,varids_in(i),tmp3D(1:len1,1:len2,1:len3))
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_put_var(ncid_out,varids_out(i),tmp3D(1:len1,1:len2,1:len3))
                    if (state /= nf90_noerr) call handle_err(state)

                    deallocate(tmp3D)
                endif

            case (4)
                if (uses_lon_first) then ! 4D with first dim = lon (len_lon_in × lenDim2 × lenDim3 × lenDim4)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(2),tmpDimName,lenDim2)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(3),tmpDimName,lenDim3)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in,dimids_var_in(4),tmpDimName,lenDim4)
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_get_var(ncid_in,varids_in(i),buf4D_in(1:len_lon_in,1:lenDim2,1:lenDim3,1:lenDim4))
                    if (state /= nf90_noerr) call handle_err(state)

                    buf4D_out(1:len_lon_in,1:lenDim2,1:lenDim3,1:lenDim4) = buf4D_in(1:len_lon_in, 1:lenDim2,1:lenDim3,1:lenDim4)
                    buf4D_out(len_lon_out,1:lenDim2,1:lenDim3,1:lenDim4) = buf4D_in(1,1:lenDim2,1:lenDim3,1:lenDim4)

                    state = nf90_put_var(ncid_out,varids_out(i),buf4D_out(1:len_lon_out,1:lenDim2,1:lenDim3,1:lenDim4))
                    if (state /= nf90_noerr) call handle_err(state)

                else ! Some other 4D (no lon‐extension). Allocate exact 4D temp:
                    state = nf90_inquire_dimension(ncid_in, dimids_var_in(1),tmpDimName,len1)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in, dimids_var_in(2),tmpDimName,len2)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in, dimids_var_in(3),tmpDimName,len3)
                    if (state /= nf90_noerr) call handle_err(state)
                    state = nf90_inquire_dimension(ncid_in, dimids_var_in(4),tmpDimName,len4)
                    if (state /= nf90_noerr) call handle_err(state)

                    allocate(tmp4D(len1,len2,len3,len4))
                    state = nf90_get_var(ncid_in,varids_in(i),tmp4D(1:len1,1:len2,1:len3,1:len4))
                    if (state /= nf90_noerr) call handle_err(state)

                    state = nf90_put_var(ncid_out,varids_out(i),tmp4D(1:len1,1:len2,1:len3,1:len4))
                    if (state /= nf90_noerr) call handle_err(state)

                    deallocate(tmp4D)
                endif

            case default
                cycle ! Skip variables with 0 dims
        end select

        deallocate(dimids_var_in)
    enddo

    ! Close both NetCDF files
    state = nf90_close(ncid_in)
    if (state /= nf90_noerr) call handle_err(state)
    state = nf90_close(ncid_out)
    if (state /= nf90_noerr) call handle_err(state)

    ! Deallocate everything
    deallocate(dimids_in,dimids_out,varids_in,varids_out,dimNames,dimLens)
    deallocate(buf1D_in,buf1D_out,buf2D_in,buf2D_out,buf3D_in,buf3D_out,buf4D_in,buf4D_out)

    write(*,*) "> ""data2reshape_Y"//trim(str)//".nc"" processed!"
enddo

END PROGRAM reshape_XIOS_output
