import xarray as xr
import matplotlib.pyplot as plt
import numpy as np


def plot_var(xx, yy, var2d, cmap='magma', extend='both', label='', title='', output="dynamico_plot.png"):
    """Draw a filled contour plot of unstructured data and save it to a file.

    Args:
        xx: x coordinates of the points, passed to ``plt.tricontourf``.
        yy: y coordinates of the points, passed to ``plt.tricontourf``.
        var2d: Values to contour at each (xx, yy) point.
        cmap: Colormap given to ``plt.tricontourf``.
        extend: One of ``'neither'``, ``'both'``, ``'min'``, ``'max'``;
            controls how out-of-range values are drawn.
        label: Colorbar label.
        title: Plot title.
        output: Path the figure is saved to.
    """
    fig = plt.figure(figsize=(5.5, 5))
    try:
        plt.subplots_adjust(left=0.01, bottom=0.01, right=0.96, top=0.94)
        # `extend` has to be set on the contour set itself: a colorbar built
        # from a ContourSet takes its extend setting from the contour set, so
        # passing it only to plt.colorbar() had no effect.
        plt.tricontourf(xx, yy, var2d, cmap=cmap, extend=extend)
        plt.colorbar(label=label, pad=0.01)
        plt.title(title, size=14)
        plt.savefig(output)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Input files read by this script: `filename_start` is opened below with
# file_type "start", `filename_startfi` with file_type "startphy".
filename_start = 'start.nc'
filename_startfi = 'startfi.nc'

def open_file(filename: str = "start.nc", file_type: str = "start") -> "tuple[xr.Dataset, str, str]":
    """Open a dataset and return it with its longitude/latitude variable names.

    Args:
        filename: Path handed to ``xr.open_dataset``.
        file_type: One of ``'start'``, ``'startphy'`` or ``'hist'``; selects
            which pair of coordinate variable names is returned.

    Returns:
        A ``(dataset, lon_name, lat_name)`` tuple.

    Raises:
        ValueError: If ``file_type`` is not one of the supported values.
    """
    # Coordinate variable names used by each supported kind of file.
    coord_names = {
        'startphy': ('longitude', 'latitude'),
        'start': ('lon_mesh', 'lat_mesh'),
        'hist': ('lon', 'lat'),
    }
    # Validate before opening so an invalid file_type never leaves a dataset
    # open. Raising a plain string is itself a TypeError in Python 3, so a
    # real exception class is required here.
    if file_type not in coord_names:
        raise ValueError(
            f"ERROR: file_type must be start or startphy or hist (got {file_type!r})"
        )
    lon_name, lat_name = coord_names[file_type]
    dataset = xr.open_dataset(filename)
    return dataset, lon_name, lat_name

# Open both input files together with the names of their lon/lat variables.
file_start, lon_start, lat_start = open_file(filename_start, "start")
file_startfi, lon_startfi, lat_startfi = open_file(filename_startfi, "startphy")

# Plot 'ps' from the "start" file.
var_name = 'ps'
title = 'Surface Pressure'
# Was "kg/kg", a mixing-ratio unit copied from the block below; a pressure
# cannot carry that unit. NOTE(review): assumes ps is stored in Pa -- confirm
# against the variable's units attribute in the file.
unit = "Pa"
plot_var(file_start[lon_start], file_start[lat_start], file_start[var_name], label=unit, title=title, output=f"start_{var_name}.png")

# Plot 'n2' from the "startphy" file.
var_name = 'n2'
title = 'N2 surf'
# NOTE(review): unit left as in the original; verify it matches how 'n2' is
# stored in the file.
unit = "kg/kg"
plot_var(file_startfi[lon_startfi], file_startfi[lat_startfi], file_startfi[var_name], label=unit, title=title, output=f"startfi_{var_name}.png")

