### script to download opacities from DACE
### adapted from https://github.com/bmorris3/shone/tree/refs/heads/opacity-downloads

import logging
import warnings
from functools import cached_property
import os
import tarfile
import shutil

import numpy as np
import xarray as xr
from astropy.table import Table
from dace_query.opacity import Molecule, Atom

from chemistry import species_name_to_common_isotopologue_name

# Keyword arguments for nearest-neighbour interpolation that extrapolates past
# the grid edges. NOTE(review): presumably splatted into
# ``xarray.Dataset.interp(**interp_kwargs)`` by other modules — confirm.
interp_kwargs = {
    'method': 'nearest',
    'kwargs': {'fill_value': "extrapolate"},
}

# Public entry points of this module; everything else is a helper.
__all__ = ['download_molecule', 'download_atom']

class AvailableOpacities:
    """Lazily fetched, cached view of the DACE opacity databases.

    The atomic and molecular tables are queried at most once per instance,
    on first access of :attr:`atomic` / :attr:`molecular`; every other method
    only filters those cached tables.
    """

    class MissingEntryError(ValueError, IndexError):
        """No database row matches the requested species / line list.

        Derives from both ``ValueError`` and ``IndexError`` so callers catching
        either exception type continue to work.
        """

    @cached_property
    def atomic(self):
        """Table of atomic opacities available on DACE (queried once)."""
        return get_atomic_database()

    @cached_property
    def molecular(self):
        """Table of molecular opacities available on DACE (queried once)."""
        return get_molecular_database()

    def get_atomic_database_entry(self, atom, charge, line_list):
        """Return the rows of :attr:`atomic` matching atom, charge and line list."""
        table = self.atomic
        return table[(
            (table['atom'] == atom) &
            (table['line_list'] == line_list) &
            (table['charge'] == charge)
        )]

    def get_molecular_database_entry(self, isotopologue, line_list):
        """Return the rows of :attr:`molecular` matching isotopologue and line list."""
        table = self.molecular
        return table[(
            (table['isotopologue'] == isotopologue) &
            (table['line_list'] == line_list)
        )]

    def get_molecular_line_lists(self, isotopologue):
        """Return the set of line-list names available for ``isotopologue``."""
        table = self.molecular
        return set(table[table['isotopologue'] == isotopologue]['line_list'])

    def get_atomic_line_lists(self, atom):
        """Return the set of line-list names available for ``atom`` (any charge)."""
        table = self.atomic
        return set(table[table['atom'] == atom]['line_list'])

    def get_atomic_latest_version(self, atom, charge, line_list):
        """Return the highest ``version`` listed for this atom/charge/line list.

        Raises
        ------
        AvailableOpacities.MissingEntryError
            If no matching row exists.
        """
        matches = self.get_atomic_database_entry(atom, charge, line_list)
        return self._latest_version(
            matches, f"atom {atom!r} (charge {charge}) with line list {line_list!r}"
        )

    def get_molecular_latest_version(self, isotopologue, line_list):
        """Return the highest ``version`` listed for this isotopologue/line list.

        Raises
        ------
        AvailableOpacities.MissingEntryError
            If no matching row exists.
        """
        matches = self.get_molecular_database_entry(isotopologue, line_list)
        return self._latest_version(
            matches, f"isotopologue {isotopologue!r} with line list {line_list!r}"
        )

    def get_atomic_pT_range(self, atom, charge, line_list):
        """Return ``((T_min, T_max), (logP_min, logP_max))`` for an atomic line list.

        Values come from the first matching row of the database.

        Raises
        ------
        AvailableOpacities.MissingEntryError
            If no matching row exists.
        """
        table = self.get_atomic_database_entry(atom, charge, line_list)
        return self._pT_range(
            table, f"atom {atom!r} (charge {charge}) with line list {line_list!r}"
        )

    def get_molecular_pT_range(self, isotopologue, line_list):
        """Return ``((T_min, T_max), (logP_min, logP_max))`` for a molecular line list.

        Values come from the first matching row of the database.

        Raises
        ------
        AvailableOpacities.MissingEntryError
            If no matching row exists.
        """
        table = self.get_molecular_database_entry(isotopologue, line_list)
        return self._pT_range(
            table, f"isotopologue {isotopologue!r} with line list {line_list!r}"
        )

    @staticmethod
    def _require_rows(matches, description):
        """Raise a descriptive error when ``matches`` is empty, else return it."""
        if len(matches) == 0:
            raise AvailableOpacities.MissingEntryError(
                f"No DACE opacity entry found for {description}."
            )
        return matches

    @staticmethod
    def _latest_version(matches, description):
        """Return the maximum of the ``version`` column of a non-empty selection."""
        matches = AvailableOpacities._require_rows(matches, description)
        return max(set(matches['version']))

    @staticmethod
    def _pT_range(table, description):
        """Read temperature and log10-pressure bounds from the first row of ``table``."""
        table = AvailableOpacities._require_rows(table, description)
        temperature_range = (
            int(table['temp_min_k'][0]),
            int(table['temp_max_k'][0])
        )
        pressure_range = (
            float(table['pressure_min_exponent_b'][0]),
            float(table['pressure_max_exponent_b'][0])
        )
        return temperature_range, pressure_range


# Shared module-level instance. Creating it does not contact DACE: the
# databases are queried on first access of its ``atomic`` / ``molecular``
# cached properties and reused afterwards.
available_opacities = AvailableOpacities()


def get_atomic_database():
    """Query DACE for all available atomic opacities.

    Returns
    -------
    astropy.table.Table
        The query result wrapped in a ``Table`` with an index on ``'atom'``.
    """
    atomic_table = Table(Atom.query_database())
    atomic_table.add_index('atom')
    return atomic_table


def get_molecular_database():
    """Query DACE for all available molecular opacities.

    Returns
    -------
    astropy.table.Table
        The query result wrapped in a ``Table`` with an index on
        ``'isotopologue'``.
    """
    molecular_table = Table(Molecule.query_database())
    molecular_table.add_index('isotopologue')
    return molecular_table

class DownloadDace:
    """Download one DACE opacity archive and convert it to a netCDF grid.

    Parameters
    ----------
    name : str
        Isotopologue (e.g. ``'48Ti-16O'``) or atom (e.g. ``'Na'``) name as
        understood by DACE.
    line_list : str
        Name of the DACE line list.
    temperature_range, pressure_range : tuple or None
        Passed unchanged to ``Molecule.download`` / ``Atom.download``.
    version : int or float
        Line-list version; converted with ``float`` before being sent to DACE.
    output_dir : str
        Directory that receives the archive, the extracted ``.bin`` files and
        the netCDF output.
    force_download : bool
        If true, download the archive and rebuild the netCDF file even when
        they already exist in ``output_dir``.
    """

    def __init__(self,
                 name='48Ti-16O',
                 line_list='Toto',
                 temperature_range=None,
                 pressure_range=None,
                 version=1,
                 output_dir="tmp",
                 force_download=False,
    ):
        self.output_dir = output_dir
        self.name = name
        self.line_list = line_list
        self.temperature_range = temperature_range
        self.pressure_range = pressure_range
        self.version = version
        self.force_download = force_download

    def _stem(self):
        """Shared ``'<name>__<line_list>'`` prefix of archive and extraction dir."""
        return self.name + '__' + self.line_list

    def _archive_target(self):
        """Create ``output_dir`` and return ``(archive file name, archive path)``."""
        os.makedirs(self.output_dir, exist_ok=True)
        archive_name = self._stem() + '.tar.gz'
        return archive_name, os.path.join(self.output_dir, archive_name)

    def dace_download_molecule(self):
        """Download the molecular archive unless it is already present.

        Returns
        -------
        str
            Path of the ``.tar.gz`` archive inside ``output_dir``.
        """
        archive_name, path = self._archive_target()
        if not self.force_download and os.path.exists(path):
            print("Molecule already downloaded")
            return path
        Molecule.download(
            self.name,
            self.line_list,
            float(self.version),
            self.temperature_range,
            self.pressure_range,
            output_directory=self.output_dir,
            output_filename=archive_name
        )
        return path

    def dace_download_atom(self, charge=0):
        """Download the atomic archive for ``charge`` unless it is already present.

        Returns
        -------
        str
            Path of the ``.tar.gz`` archive inside ``output_dir``.
        """
        archive_name, path = self._archive_target()
        if not self.force_download and os.path.exists(path):
            print("Atom already downloaded")
            return path
        Atom.download(
            self.name, charge,
            self.line_list, float(self.version),
            self.temperature_range,
            self.pressure_range,
            output_directory=self.output_dir,
            output_filename=archive_name
        )
        # NOTE: atom downloads have been observed to produce no archive
        # (reason unknown); in that case the returned path does not exist.
        return path

    def untar_bin_files(self, archive_name):
        """Extract only the regular ``.bin`` files of ``archive_name``.

        Files go to ``<output_dir>/<name>__<line_list>``, which is returned.
        """
        def bin_files(members):
            for tarinfo in members:
                if tarinfo.isfile() and os.path.splitext(tarinfo.name)[1] == ".bin":
                    yield tarinfo

        # Where available (Python 3.12+ and security backports), the 'data'
        # filter refuses absolute paths and members escaping ``path``.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        path = os.path.join(self.output_dir, self._stem())
        with tarfile.open(archive_name, 'r') as tar:
            # A generator keeps iteration and extraction moving forward through
            # the (possibly compressed) stream instead of reading it twice.
            tar.extractall(path=path, members=bin_files(tar), **extract_kwargs)

        return path

    @staticmethod
    def _iter_bin_files(opacity_dir):
        """Yield ``(path, filename)`` for every ``.bin`` file below ``opacity_dir``."""
        for dirpath, _dirnames, filenames in os.walk(opacity_dir):
            for filename in filenames:
                if filename.endswith('.bin'):
                    yield os.path.join(dirpath, filename), filename

    @staticmethod
    def _parse_bin_filename(filename):
        """Decode the grid point encoded in a DACE ``.bin`` file name.

        The name is split on ``'_'``. Fields 1 and 2 are the start and end of
        the wavenumber range, field 3 is the temperature, and field 4 is the
        pressure: a sign letter (``'p'`` positive, anything else negative)
        followed by 100 * log10 of the pressure.

        Returns
        -------
        tuple
            ``(wn_start, wn_end, temperature, pressure)``.
        """
        fields = filename.split('_')
        wn_start = int(fields[1])
        wn_end = int(fields[2])
        temperature = int(fields[3])
        pressure_field = fields[4]
        sign = 1 if pressure_field[0] == 'p' else -1
        pressure = 10 ** (sign * float(pressure_field[1:].split('.')[0]) / 100)
        return wn_start, wn_end, temperature, pressure

    def opacity_dir_to_netcdf(self, opacity_dir, outpath):
        """Assemble the ``.bin`` opacity files in ``opacity_dir`` into one netCDF file.

        The dataset has one float32 variable ``opacity`` with dimensions
        ``(temperature, pressure, wavelength)``. Wavelengths are in micron and
        in increasing order.

        Parameters
        ----------
        opacity_dir : str
            Directory searched recursively for ``*.bin`` files.
        outpath : str
            Output file name relative to ``output_dir``. ``'.nc'`` is appended
            if missing.

        Returns
        -------
        str
            Path of the netCDF file. If the file already exists and
            ``force_download`` is false, it is returned without being rebuilt.

        Raises
        ------
        FileNotFoundError
            If ``opacity_dir`` contains no ``.bin`` files.
        """
        outpath = os.path.join(self.output_dir, outpath)
        # Normalise the suffix *before* the existence check so the check looks
        # at the file that will actually be written.
        if not outpath.endswith('.nc'):
            outpath += '.nc'
        if os.path.exists(outpath) and not self.force_download:
            print(f"{outpath} already exists")
            return outpath

        # Parse each file name once: (path, wn_start, wn_end, temperature, pressure).
        entries = [
            (path, *self._parse_bin_filename(filename))
            for path, filename in self._iter_bin_files(opacity_dir)
        ]
        if not entries:
            raise FileNotFoundError(f"No .bin opacity files found in {opacity_dir!r}")

        # Wavenumber grid (step 0.01) over the range shared by all files.
        wn_start = max(entry[1] for entry in entries)
        wn_end = min(entry[2] for entry in entries)
        wavenumber = np.arange(wn_start, wn_end, 0.01)
        with warnings.catch_warnings():
            # Suppress the divide-by-zero warning when the range starts at 0;
            # the resulting leading inf is dropped below.
            warnings.simplefilter('ignore', RuntimeWarning)
            # Convert to micron
            wavelength = 1 / wavenumber / 1e-4
        # Drop the first point (as is done for the opacities) and reverse so
        # wavelengths increase.
        unique_wavelengths = wavelength[1:][::-1]
        n_wavelengths = len(unique_wavelengths)

        tgrid = np.sort(list({entry[3] for entry in entries}))
        pgrid = np.sort(list({entry[4] for entry in entries}))

        extrapolate_pgrid = len(pgrid) == 1
        if extrapolate_pgrid:
            # A single pressure cannot be interpolated over: add the mirrored
            # point 10**(-log10(p)) carrying the same opacities (i.e. treat
            # the opacity as pressure independent). Sorted for monotonic lookup.
            pgrid = np.sort(np.concatenate([pgrid, 10 ** (-1 * np.log10(pgrid))]))

        opacity_grid = np.zeros(
            (len(tgrid), len(pgrid), n_wavelengths), dtype='float32'
        )

        for path, _wn_start, _wn_end, temperature, pressure in entries:
            # Skip the first value and reverse to align with unique_wavelengths.
            opacity = np.fromfile(path, dtype=np.float32)[1:][::-1]
            spectrum = opacity[:n_wavelengths]

            temperature_ind = np.argmin(np.abs(tgrid - temperature))
            if extrapolate_pgrid:
                # Fill every pressure column, including the mirrored one (this
                # also works when p == 1 and the mirrored point coincides).
                opacity_grid[temperature_ind, :, :] = spectrum
            else:
                pressure_ind = np.argmin(np.abs(pgrid - pressure))
                opacity_grid[temperature_ind, pressure_ind, :] = spectrum

        ds = xr.Dataset(
            data_vars=dict(
                opacity=(["temperature", "pressure", "wavelength"],
                         opacity_grid)
            ),
            coords=dict(
                temperature=(["temperature"], tgrid),
                pressure=(["pressure"], pgrid),
                wavelength=unique_wavelengths
            )
        )

        outdir = os.path.dirname(outpath)
        if outdir:
            os.makedirs(outdir, exist_ok=True)

        ds.to_netcdf(outpath, encoding={'opacity': {'dtype': 'float32'}})
        return outpath

    def clean_up(self, bin_dir, archive_name):
        """Delete the downloaded archive and the directory of extracted files."""
        os.remove(archive_name)
        shutil.rmtree(bin_dir)


def download_molecule(
    molecule=None,
    isotopologue=None,
    line_list='first-found',
    temperature_range=None,
    pressure_range=None,
    version=None,
    output_dir="datadir/dace",
    force_download=False,
):
    """
    Download molecular opacity data from DACE and convert it to netCDF.

    .. warning::
        This generates *very* large files. Only run this
        method if you have ~6 GB available per molecule.

    Parameters
    ----------
    molecule : str, optional
        Common name for the molecule, for example: "H2O". When given, it is
        converted to an isotopologue name and takes precedence over
        ``isotopologue``.
    isotopologue : str, optional
        For example, "1H2-16O" for water.
    line_list : str or None, default is ``'first-found'``, optional
        For example, "POKAZATEL" for water. With ``'first-found'``, the last
        name in sorted order among the available line lists is chosen. With
        ``None``, the available line lists are printed and nothing is
        downloaded.
    temperature_range : tuple, optional
        Tuple of integers specifying the min and max
        temperature requested. Defaults to the full
        range of available temperatures.
    pressure_range : tuple, optional
        Tuple of floats specifying the log base 10 of the
        min and max pressure [bar] requested. Defaults to the full
        range of available pressures.
    version : float, optional
        Version number of the line list in DACE. Defaults to the
        latest version.
    output_dir : str, optional
        Directory for the archive, the extracted files and the netCDF file.
    force_download : bool, optional
        Re-download the archive even if it already exists in ``output_dir``.

    Returns
    -------
    str or None
        File name ``'<isotopologue>__<line_list>.nc'`` of the netCDF output,
        relative to ``output_dir``; ``None`` when ``line_list`` is ``None``.

    Raises
    ------
    ValueError
        If no species is given, no line list is available, or the requested
        line list is not available.
    """
    if molecule is not None:
        isotopologue = species_name_to_common_isotopologue_name(molecule)
    if isotopologue is None:
        raise ValueError("Either `molecule` or `isotopologue` must be given.")

    available_line_lists = available_opacities.get_molecular_line_lists(isotopologue)
    print(f"Available line lists for {isotopologue}:")
    print(available_line_lists)

    if line_list is None:
        return None

    if line_list == 'first-found':
        if not available_line_lists:
            raise ValueError(f"No line lists are available on DACE for {isotopologue}.")
        # Deterministic pick among the unordered set: the last name in sorted order.
        line_list = max(available_line_lists)
        logging.warning(f"Using first-found line list for {isotopologue}: '{line_list}'")

    elif line_list not in available_line_lists:
        raise ValueError(f"The requested '{line_list}' is not in the set of "
                         f"available line lists {available_line_lists}.")

    if version is None:
        version = available_opacities.get_molecular_latest_version(isotopologue, line_list)
        logging.warning(f"Using latest version of the line "
                        f"list '{line_list}' for {isotopologue}: {version}")

    if temperature_range is None or pressure_range is None:
        dace_temp_range, dace_press_range = available_opacities.get_molecular_pT_range(
            isotopologue, line_list
        )
        if temperature_range is None:
            temperature_range = dace_temp_range
        if pressure_range is None:
            pressure_range = dace_press_range

    download_dace = DownloadDace(
        isotopologue, line_list, temperature_range, pressure_range, version,
        output_dir, force_download=force_download
    )
    archive_name = download_dace.dace_download_molecule()
    bin_dir = download_dace.untar_bin_files(archive_name)

    netcdf_name = isotopologue + '__' + line_list + '.nc'
    download_dace.opacity_dir_to_netcdf(bin_dir, netcdf_name)
    # The archive and extracted .bin files are kept on disk; call
    # download_dace.clean_up(bin_dir, archive_name) to delete them.
    return netcdf_name


def download_atom(atom, charge, line_list='first-found',
                  temperature_range=None, pressure_range=None, version=None, output_dir="tmp",
                  force_download=False):
    """
    Download atomic opacity data from DACE and convert it to netCDF.

    .. warning::
        This generates *very* large files. Only run this
        method if you have ~6 GB available per species.

    Parameters
    ----------
    atom : str
        For example, "Na" for sodium.
    charge : int
        For example, 0 for neutral.
    line_list : str or None, default is ``'first-found'``, optional
        For example, "Kurucz". With ``'first-found'``, the last name in sorted
        order among the line lists available for this atom/charge is chosen.
        With ``None``, the available line lists are printed and nothing is
        downloaded.
    temperature_range : tuple, optional
        Tuple of integers specifying the min and max
        temperature requested. Defaults to the full
        range of available temperatures.
    pressure_range : tuple, optional
        Tuple of floats specifying the log base 10 of the
        min and max pressure [bar] requested. Defaults to the full
        range of available pressures.
    version : float, optional
        Version number of the line list in DACE. Defaults to the
        latest version.
    output_dir : str, optional
        Directory for the archive, the extracted files and the netCDF file.
    force_download : bool, optional
        Re-download the archive even if it already exists in ``output_dir``.

    Returns
    -------
    str or None
        File name ``'<atom>_<charge>__<line_list>.nc'`` of the netCDF output,
        relative to ``output_dir``; ``None`` when ``line_list`` is ``None``.

    Raises
    ------
    ValueError
        If no line list is available for this atom/charge, or the requested
        line list is not available.
    """
    atomic_table = available_opacities.atomic
    # Restrict to the requested charge so that the version / pT-range lookups
    # below cannot come back empty for a line list that only exists for
    # other ionisation states.
    available_line_lists = set(
        atomic_table[(atomic_table['atom'] == atom) &
                     (atomic_table['charge'] == charge)]['line_list']
    )
    print(f"Available line lists for {atom}:")
    print(available_line_lists)

    if line_list is None:
        return None

    if line_list == 'first-found':
        if not available_line_lists:
            raise ValueError(f"No line lists are available on DACE for {atom} "
                             f"with charge {charge}.")
        # Deterministic pick among the unordered set: the last name in sorted order.
        line_list = max(available_line_lists)
        logging.warning(f"Using first-found line list for {atom}: '{line_list}'")

    elif line_list not in available_line_lists:
        raise ValueError(f"The requested '{line_list}' is not in the set of "
                         f"available line lists {available_line_lists}.")

    if version is None:
        version = available_opacities.get_atomic_latest_version(atom, charge, line_list)
        logging.warning(f"Using latest version of the line "
                        f"list '{line_list}' for {atom}: {version}")

    if temperature_range is None or pressure_range is None:
        dace_temp_range, dace_press_range = available_opacities.get_atomic_pT_range(
            atom, charge, line_list
        )
        if temperature_range is None:
            temperature_range = dace_temp_range
        if pressure_range is None:
            pressure_range = dace_press_range

    download_dace = DownloadDace(
        atom, line_list, temperature_range, pressure_range, version,
        output_dir, force_download=force_download
    )
    archive_name = download_dace.dace_download_atom(charge)
    bin_dir = download_dace.untar_bin_files(archive_name)

    netcdf_name = atom + '_' + str(int(charge)) + '__' + line_list + '.nc'
    download_dace.opacity_dir_to_netcdf(bin_dir, netcdf_name)
    # The archive and extracted .bin files are kept on disk; call
    # download_dace.clean_up(bin_dir, archive_name) to delete them.
    return netcdf_name
