##############################################################
### Python script to visualize a variable in a NetCDF file ###
##############################################################

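### This script displays a variable of a NetCDF file as a latitude/longitude map.
### The file name, the variable to display and, when the variable has extra
### dimensions such as time or altitude, the index to plot along them are
### requested from the user in the terminal.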

import os
import readline
import glob
from netCDF4 import Dataset
import matplotlib.pyplot as plt
import numpy as np

##############################################################
### Readline completer for file-name autocompletion
def complete(text, state):
    """Readline completer that expands *text* into matching file/directory paths.

    Readline calls this repeatedly with state = 0, 1, 2, ... and stops at the
    first ``None``. A trailing ``*`` is appended to *text* (unless it already
    contains one), ``~`` is expanded, and the pattern is matched with
    :func:`glob.glob`. Directories get a trailing ``/`` so completion can
    continue inside them.

    Args:
        text: the fragment currently being completed.
        state: index of the candidate to return.

    Returns:
        The *state*-th candidate path, or ``None`` when there are no more.
    """
    pattern = text if '*' in text else text + '*'
    # glob's order depends on the filesystem; sort so the successive calls
    # readline makes see the candidates in a deterministic order.
    matches = sorted(glob.glob(os.path.expanduser(pattern)))
    # Add '/' if the match is a directory
    matches = [match + '/' if os.path.isdir(match) else match for match in matches]
    try:
        return matches[state]
    except IndexError:
        return None

### Factory for a readline completer over a fixed list of variable names
def complete_variable_names(variable_names):
    """Return a readline completer offering the names that start with the typed text.

    Args:
        variable_names: the candidate names, scanned in order on every call.

    Returns:
        A ``completer(text, state)`` function that gives the *state*-th name
        beginning with *text*, or ``None`` once the matches are exhausted.
    """
    def completer(text, state):
        hits = [candidate for candidate in variable_names if candidate.startswith(text)]
        return hits[state] if state < len(hits) else None

    return completer

### Function to ask the user for a valid index along one dimension
def _prompt_index(label, size):
    """Ask on the terminal for an index in ``[0, size)`` and return it.

    A dimension of length 1 has a single possible index, so 0 is returned
    without prompting. Answers that are not integers or fall outside the
    valid range are rejected and the question is repeated.

    Args:
        label: name shown in the prompt (e.g. ``"time"``).
        size: length of the dimension; must be at least 1.

    Returns:
        The selected index.

    Raises:
        ValueError: if *size* is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"cannot select an index in an empty {label} dimension")
    if size == 1:
        return 0
    while True:
        answer = input(f"Enter the {label} index (0 to {size - 1}): ")
        try:
            index = int(answer)
        except ValueError:
            print(f"'{answer}' is not an integer, please try again.")
            continue
        if 0 <= index < size:
            return index
        print(f"Index {index} is out of range, please try again.")


### Function to draw a 2D field on a longitude/latitude grid and show it
def _plot_map(longitude, latitude, data_slice, variable_name, units):
    """Draw *data_slice* as filled contours against *longitude*/*latitude* and display it."""
    plt.figure(figsize=(10, 6))
    plt.contourf(longitude, latitude, data_slice, cmap='jet')
    plt.colorbar(label=f"{variable_name.capitalize()} ({units})")
    plt.xlabel('Longitude (degrees)')
    plt.ylabel('Latitude (degrees)')
    plt.title(f"{variable_name.capitalize()} visualization")
    plt.show()


### Function to visualize a variable from a NetCDF file
def visualize_variable():
    """Interactively pick a NetCDF file and variable, then plot a 2D map of it.

    The user is asked (with tab completion) for a file name and a variable
    name. The last two dimensions of the variable are plotted against the
    file's ``latitude`` and ``longitude`` variables. For every other (leading)
    dimension, such as ``Time`` or ``altitude``, the user is asked which index
    to plot. Problems are reported on the terminal and the function returns
    without plotting. These problems are: file missing or unreadable, unknown
    variable, variable with fewer than 2 dimensions or empty, and missing
    coordinate variables.
    """
    # Complete whole paths: readline's default delimiters include '/', '~'
    # and '-', which would hand the completer only the last path component.
    readline.set_completer_delims(' \t\n')
    readline.set_completer(complete)
    readline.parse_and_bind('tab: complete')
    file = input("Enter the name of the NetCDF file: ")

    # Open the NetCDF file. Only the open itself is guarded, so OSErrors
    # raised later are not misreported as "cannot open".
    try:
        dataset = Dataset(file, mode='r')
    except FileNotFoundError:
        print(f"File '{file}' not found.")
        return
    except OSError as err:
        # e.g. the file exists but is not a readable NetCDF file
        print(f"Could not open '{file}' as a NetCDF file: {err}")
        return

    # The 'with' block closes the file on every exit path, including
    # early returns and exceptions.
    with dataset:
        # Display available variables
        variable_names = list(dataset.variables.keys())
        print("Available variables:\n", variable_names)

        # Ask for the variable to display
        readline.set_completer(complete_variable_names(variable_names))
        variable_name = input("\nEnter the name of the variable you want to visualize: ")
        if variable_name not in dataset.variables:
            print(f"Variable '{variable_name}' not found in the dataset.")
            return

        variable = dataset.variables[variable_name]
        dimensions = variable.dimensions
        print(f"\nDimensions of '{variable_name}': {dimensions}")

        if len(dimensions) < 2:
            print(f"Variable '{variable_name}' needs at least 2 dimensions to be plotted as a map.")
            return
        if 0 in variable.shape:
            print(f"Variable '{variable_name}' is empty.")
            return
        missing = [name for name in ('latitude', 'longitude') if name not in dataset.variables]
        if missing:
            print(f"Coordinate variable(s) {missing} not found in the dataset.")
            return

        latitude = dataset.variables['latitude'][:]
        longitude = dataset.variables['longitude'][:]

        # The last two axes are assumed to be the (latitude, longitude) grid.
        # Every leading axis (e.g. Time, altitude) gets a user-chosen index,
        # in whatever order the axes appear. Indexing the netCDF variable
        # directly reads only this 2D slice instead of the whole array.
        selection = []
        for dim, size in zip(dimensions[:-2], variable.shape[:-2]):
            label = 'time' if dim == 'Time' else dim
            selection.append(_prompt_index(label, size))
        selection += [slice(None), slice(None)]
        data_slice = variable[tuple(selection)]

        # Use the variable's own 'units' attribute when the file provides one.
        units = getattr(variable, 'units', 'units')

    # Plot after the file is closed; all data needed is already in memory.
    _plot_map(longitude, latitude, data_slice, variable_name, units)

### Run the interactive visualization only when executed as a script, so that
### importing this module does not start prompting the user.
if __name__ == "__main__":
    visualize_variable()
