
Util Module

Supporting utilities for climate data processing, unit conversions, and visualization.

Overview

The climakitae.util module provides:

- General utilities — Helper functions for data manipulation and analysis
- Warming level calculations — Global warming level trajectory computations
- Unit conversions — Climate variable unit transformation
- Colormaps — Custom climate-focused colormaps
- Cluster management — Dask distributed computing setup
- Logging — Logger configuration for climakitae

Core Utilities

downscaling_method_as_list(downscaling_method)

Convert string-based radio button values to a Python list.

Parameters:

    downscaling_method : str, required
        one of "Dynamical", "Statistical", or "Dynamical+Statistical"

Returns:

    method_list : list
        one of ["Dynamical"], ["Statistical"], or ["Dynamical", "Statistical"]

Source code in climakitae/util/utils.py
def downscaling_method_as_list(downscaling_method: str) -> list[str]:
    """Function to convert string based radio button values to python list.

    Parameters
    ----------
    downscaling_method : str
        one of "Dynamical", "Statistical", or "Dynamical+Statistical"

    Returns
    -------
    method_list : list
        one of ["Dynamical"], ["Statistical"], or ["Dynamical","Statistical"]

    """
    method_list = []
    if downscaling_method == "Dynamical+Statistical":
        method_list = ["Dynamical", "Statistical"]
    else:
        method_list = [downscaling_method]
    return method_list
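
A minimal usage sketch:

>>> from climakitae.util.utils import downscaling_method_as_list
>>> downscaling_method_as_list("Dynamical+Statistical")
['Dynamical', 'Statistical']
>>> downscaling_method_as_list("Statistical")
['Statistical']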

area_average(dset)

Weighted area-average

Parameters:

    dset : xr.Dataset, required
        one dataset from the catalog

Returns:

    xr.Dataset
        sub-setted output data

Source code in climakitae/util/utils.py
def area_average(dset: xr.Dataset) -> xr.Dataset:
    """Weighted area-average

    Parameters
    ----------
    dset : xr.Dataset
        one dataset from the catalog

    Returns
    -------
    xr.Dataset
        sub-setted output data

    """
    weights = np.cos(np.deg2rad(dset.lat))
    if set(["x", "y"]).issubset(set(dset.dims)):
        # WRF data has x,y
        dset = dset.weighted(weights).mean(["x", "y"])
    elif set(["lat", "lon"]).issubset(set(dset.dims)):
        # LOCA data has lat, lon
        dset = dset.weighted(weights).mean(["lat", "lon"])
    return dset
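
A usage sketch with a small synthetic LOCA-style dataset; the "tasmax" variable name below is hypothetical:

>>> import numpy as np
>>> import xarray as xr
>>> from climakitae.util.utils import area_average
>>> ds = xr.Dataset(
...     {"tasmax": (("lat", "lon"), np.random.rand(10, 10))},
...     coords={"lat": np.linspace(32, 42, 10), "lon": np.linspace(-124, -114, 10)},
... )
>>> ds_avg = area_average(ds)  # cosine-latitude weighted mean over lat/lon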

read_csv_file(rel_path, index_col=UNSET, parse_dates=False)

Read CSV file into pandas DataFrame

Parameters:

    rel_path : str, required
        path to CSV file relative to this util python file
    index_col : str, optional
        CSV column to index DataFrame on (default UNSET)
    parse_dates : bool, optional
        Whether to have pandas parse the date strings (default False)

Returns:

    pd.DataFrame

Source code in climakitae/util/utils.py
def read_csv_file(
    rel_path: str, index_col: str = UNSET, parse_dates: bool = False
) -> pd.DataFrame:
    """Read CSV file into pandas DataFrame

    Parameters
    ----------
    rel_path : str
        path to CSV file relative to this util python file
    index_col : str
        CSV column to index DataFrame on
    parse_dates : boolean
        Whether to have pandas parse the date strings

    Returns
    -------
    pd.DataFrame

    """
    return pd.read_csv(
        _package_file_path(rel_path),
        index_col=None if index_col is UNSET else index_col,
        parse_dates=parse_dates,
        na_values=[
            "",
            "#N/A",
            "#N/A N/A",
            "#NA",
            "-1.#IND",
            "-1.#QNAN",
            "-NaN",
            "-nan",
            "1.#IND",
            "1.#QNAN",
            "<NA>",
            "N/A",
            "NA",
            "NULL",
            "NaN",
            "n/a",
            "nan",
            "null ",
        ],
        keep_default_na=False,
    )
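
A hedged usage sketch; the path and column name below are hypothetical and should point at a CSV that actually ships inside the climakitae package:

>>> from climakitae.util.utils import read_csv_file
>>> df = read_csv_file("data/example_lookup.csv", index_col="station")  # hypothetical path/column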

write_csv_file(df, rel_path)

Write CSV file from pandas DataFrame

Parameters:

    df : pd.DataFrame, required
        pandas DataFrame to write out
    rel_path : str, required
        path to CSV file relative to this util python file

Returns:

    None

Source code in climakitae/util/utils.py
def write_csv_file(df: pd.DataFrame, rel_path: str) -> None:
    """Write CSV file from pandas DataFrame

    Parameters
    ----------
    df : pd.DataFrame
        pandas DataFrame to write out
    rel_path : str
        path to CSV file relative to this util python file

    Returns
    -------
    None

    """
    return df.to_csv(_package_file_path(rel_path))

f_to_k(f)

Convert temperature from degrees Fahrenheit to Kelvin.

Accepts scalars or array-like inputs (returns numpy array for array-like).

Parameters:

    f : float or array-like, required
        Degrees Fahrenheit.

Returns:

    float or numpy.ndarray
        Temperature in Kelvin.

Source code in climakitae/util/utils.py
def f_to_k(f: Union[float, Iterable[float]]) -> Union[float, "np.ndarray"]:
    """Convert temperature from degrees Fahrenheit to Kelvin.

    Accepts scalars or array-like inputs (returns numpy array for array-like).

    Parameters
    ----------
    f : float or array-like
        Degrees Fahrenheit.

    Returns
    -------
    float or numpy.ndarray
        Temperature in Kelvin.

    """
    # Use numpy to handle both scalar and array-like inputs
    return (np.asarray(f) - 32.0) * 5.0 / 9.0 + 273.15
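
A minimal usage sketch:

>>> import numpy as np
>>> from climakitae.util.utils import f_to_k
>>> f_to_k(32.0)                        # 273.15 (freezing point)
>>> f_to_k(np.array([32.0, 212.0]))     # array([273.15, 373.15])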

get_closest_gridcell(data, lat, lon, print_coords=True)

From input gridded data, get the closest VALID gridcell to a lat, lon coordinate pair.

This function first transforms the lat, lon coordinates to the gridded data's projection. It then performs a nearest-neighbor lookup on the data's coordinate indexes to select the closest gridcell.

Parameters:

    data : xr.DataArray | xr.Dataset, required
        Gridded data (can be backed by numpy or dask arrays)
    lat : float, required
        Latitude or y value of coordinate pair
    lon : float, required
        Longitude or x value of coordinate pair
    print_coords : bool, optional
        Print closest coordinates? Default is True. Set to False for backend use.

Returns:

    xr.Dataset | xr.DataArray | None
        Grid cell closest to input lat, lon coordinate pair; returns same type as input. The result preserves lazy evaluation if the input was lazy.

See also

xr.DataArray.isel

Source code in climakitae/util/utils.py
def get_closest_gridcell(
    data: xr.Dataset | xr.DataArray, lat: float, lon: float, print_coords: bool = True
) -> xr.Dataset | xr.DataArray | None:
    """From input gridded data, get the closest VALID gridcell to a lat, lon coordinate pair.

    This function first transforms the lat, lon coords to the gridded data's projection.
    It then performs a nearest-neighbor index lookup to select the closest gridcell.

    Parameters
    ----------
    data : xr.DataArray | xr.Dataset
        Gridded data (can be backed by numpy or dask arrays)
    lat : float
        Latitude or y value of coordinate pair
    lon : float
        Longitude or x value of coordinate pair
    print_coords : bool, optional
        Print closest coordinates?
        Default to True. Set to False for backend use.

    Returns
    -------
    xr.Dataset | xr.DataArray | None
        Grid cell closest to input lat,lon coordinate pair, returns same type as input.
        The result preserves lazy evaluation if the input was lazy.

    See also
    --------
    xr.DataArray.isel

    """
    # Identify spatial dimensions and transform coordinates
    # lat_dim is y-axis (north-south), lon_dim is x-axis (east-west)
    lat_dim, lon_dim = _get_spatial_dims(data)
    lat_coord, lon_coord = _transform_coords_to_data_crs(
        data, lat, lon, lat_dim, lon_dim
    )

    # Find nearest indices
    lat_idx = data[lat_dim].to_index().get_indexer([lat_coord], method="nearest")[0]
    lon_idx = data[lon_dim].to_index().get_indexer([lon_coord], method="nearest")[0]

    if lat_idx == -1 or lon_idx == -1:
        print("Input coordinate is OUTSIDE of data extent. Returning None.")
        return None

    # Check for valid data at closest gridcell
    gridcell = data.isel({lat_dim: lat_idx, lon_dim: lon_idx})
    test_data = _reduce_to_single_point(gridcell, lat_dim, lon_dim)

    if _check_has_valid_data(test_data):
        if print_coords:
            _print_closest_coords(gridcell, lat, lon, lat_dim, lon_dim)
        return gridcell

    # If closest gridcell is all NaN, search in 3x3 window and average valid cells
    valid_data = _search_nearby_valid_gridcells(
        data, lat_idx, lon_idx, lat_dim, lon_dim
    )

    if len(valid_data) > 0:
        closest_gridcell, coord_values = _average_gridcells_preserve_coords(
            valid_data, lat_dim, lon_dim
        )
        if print_coords:
            _print_closest_coords(
                closest_gridcell,
                lat,
                lon,
                lat_dim,
                lon_dim,
                is_averaged=True,
                coord_values=coord_values,
            )
        return closest_gridcell

    return None
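
A usage sketch; "data" is assumed to be a gridded DataArray or Dataset already retrieved through climakitae (e.g. WRF output with x/y dimensions):

>>> from climakitae.util.utils import get_closest_gridcell
>>> gridcell = get_closest_gridcell(data, lat=38.58, lon=-121.49)  # point near Sacramento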

get_closest_gridcells(data, lats, lons, print_coords=True, bbox_buffer=5)

Find the nearest grid cell(s) for given latitude and longitude coordinates.

This function uses vectorized operations to efficiently find closest gridcells for multiple coordinate pairs at once. For a single point, it delegates to get_closest_gridcell.

Parameters:

    data : xr.DataArray | xr.Dataset, required
        Gridded dataset with (x, y) or (lat, lon) dimensions.
    lats : float | Iterable[float], required
        Latitude coordinate(s).
    lons : float | Iterable[float], required
        Longitude coordinate(s).
    print_coords : bool, optional
        Print closest coordinates for each point. Default is True. Note: for large numbers of points, printing is automatically suppressed.
    bbox_buffer : int, optional
        Number of grid cells to add as buffer around the bounding box when pre-clipping large datasets. Default is 5.

Returns:

    xr.Dataset | xr.DataArray | None
        Nearest grid cell(s) or None if no valid match is found. If multiple coordinates are provided, results are concatenated along the 'points' dimension.

See Also

get_closest_gridcell

Notes

For large datasets with many target points, this function first clips the data to a bounding box around the target points. This dramatically reduces the Dask task graph complexity and improves performance for downstream operations.

Source code in climakitae/util/utils.py
def get_closest_gridcells(
    data: xr.Dataset | xr.DataArray,
    lats: Iterable[float] | float,
    lons: Iterable[float] | float,
    print_coords: bool = True,
    bbox_buffer: int = 5,
) -> xr.Dataset | xr.DataArray | None:
    """Find the nearest grid cell(s) for given latitude and longitude coordinates.

    This function uses vectorized operations to efficiently find closest gridcells
    for multiple coordinate pairs at once. For a single point, it delegates to
    get_closest_gridcell.

    Parameters
    ----------
    data : xr.DataArray | xr.Dataset
        Gridded dataset with (x, y) or (lat, lon) dimensions.
    lats : float | Iterable[float]
        Latitude coordinate(s).
    lons : float | Iterable[float]
        Longitude coordinate(s).
    print_coords : bool, optional
        Print closest coordinates for each point. Default is True.
        Note: For large numbers of points, printing is automatically suppressed.
    bbox_buffer : int, optional
        Number of grid cells to add as buffer around the bounding box when
        pre-clipping large datasets. Default is 5.

    Returns
    -------
    xr.Dataset | xr.DataArray | None
        Nearest grid cell(s) or None if no valid match is found.
        If multiple coordinates are provided, results are concatenated along 'points' dimension.

    See Also
    --------
    get_closest_gridcell

    Notes
    -----
    For large datasets with many target points, this function first clips the data
    to a bounding box around the target points. This dramatically reduces the Dask
    task graph complexity and improves performance for downstream operations.

    """
    # Convert single values to arrays for uniform handling
    lats_arr = np.atleast_1d(np.asarray(lats))
    lons_arr = np.atleast_1d(np.asarray(lons))

    # Ensure lats and lons have the same length
    if len(lats_arr) != len(lons_arr):
        raise ValueError(
            f"lats and lons must have the same length, got {len(lats_arr)} and {len(lons_arr)}"
        )

    n_points = len(lats_arr)
    print(f"Processing {n_points} coordinate pair(s)...")

    # Suppress per-point printing for large numbers of points
    if n_points > 10:
        print_coords = False

    # Identify spatial dimensions and transform all coordinates at once
    lat_dim, lon_dim = _get_spatial_dims(data)
    print(f"  Spatial dimensions: {lat_dim}, {lon_dim}")

    lat_coords, lon_coords = _transform_coords_to_data_crs_vectorized(
        data, lats_arr, lons_arr, lat_dim, lon_dim
    )

    # Get coordinate arrays from data
    lat_index = data[lat_dim].to_index()
    lon_index = data[lon_dim].to_index()
    print(f"  Data grid size: {len(lat_index)} x {len(lon_index)}")

    # OPTIMIZATION: Pre-clip to bounding box to reduce Dask task graph complexity
    # This is critical for large datasets with many scattered points
    lat_min_idx = lat_index.get_indexer([lat_coords.min()], method="nearest")[0]
    lat_max_idx = lat_index.get_indexer([lat_coords.max()], method="nearest")[0]
    lon_min_idx = lon_index.get_indexer([lon_coords.min()], method="nearest")[0]
    lon_max_idx = lon_index.get_indexer([lon_coords.max()], method="nearest")[0]

    # Add buffer and clamp to valid range
    lat_min_idx = max(0, min(lat_min_idx, lat_max_idx) - bbox_buffer)
    lat_max_idx = min(len(lat_index) - 1, max(lat_min_idx, lat_max_idx) + bbox_buffer)
    lon_min_idx = max(0, min(lon_min_idx, lon_max_idx) - bbox_buffer)
    lon_max_idx = min(len(lon_index) - 1, max(lon_min_idx, lon_max_idx) + bbox_buffer)

    bbox_lat_size = lat_max_idx - lat_min_idx + 1
    bbox_lon_size = lon_max_idx - lon_min_idx + 1
    original_size = len(lat_index) * len(lon_index)
    bbox_size = bbox_lat_size * bbox_lon_size

    # Only pre-clip if it reduces spatial size significantly (>50% reduction)
    if bbox_size < original_size * 0.5:
        print(
            f"  Pre-clipping to bounding box: {bbox_lat_size} x {bbox_lon_size} "
            f"({bbox_size / original_size * 100:.1f}% of original)"
        )
        data = data.isel(
            {
                lat_dim: slice(lat_min_idx, lat_max_idx + 1),
                lon_dim: slice(lon_min_idx, lon_max_idx + 1),
            }
        )
        # Update indices to reference the clipped data
        lat_index = data[lat_dim].to_index()
        lon_index = data[lon_dim].to_index()

    # Find nearest indices for all points at once (vectorized)
    lat_indices = lat_index.get_indexer(lat_coords, method="nearest")
    lon_indices = lon_index.get_indexer(lon_coords, method="nearest")

    # Check for out-of-bounds points
    valid_mask = (lat_indices != -1) & (lon_indices != -1)
    if not np.any(valid_mask):
        print("All input coordinates are OUTSIDE of data extent. Returning None.")
        return None

    if not np.all(valid_mask):
        n_invalid = np.sum(~valid_mask)
        print(
            f"  Warning: {n_invalid} point(s) are outside data extent and will be excluded."
        )

    # Filter to valid indices only
    valid_lat_indices = lat_indices[valid_mask]
    valid_lon_indices = lon_indices[valid_mask]
    valid_lats = lats_arr[valid_mask]
    valid_lons = lons_arr[valid_mask]
    n_valid = len(valid_lat_indices)
    print(f"  Found {n_valid} valid point(s) within data extent")

    # Check for invalid (ocean/masked) points using landmask BEFORE extracting data
    # This avoids loading the full dataset just to check for NaNs
    # SKIP expensive validity check for large point sets (>100 points) - the compute
    # call triggers the entire Dask task graph which is extremely slow
    needs_nan_handling = []
    skip_validity_check = n_valid > 100

    if skip_validity_check:
        print(f"  Skipping validity check for {n_valid} points (too expensive)")
    elif "landmask" in data.coords or "landmask" in getattr(data, "data_vars", {}):
        print("  Checking landmask for valid land points...")
        landmask = data["landmask"]
        # Compute if dask-backed (this is just 2D, so cheap)
        if hasattr(landmask.data, "compute"):
            landmask = landmask.compute()
        landmask_values = landmask.values

        # Check which target points are on land vs water
        is_land = np.array(
            [
                landmask_values[lat_idx, lon_idx] == 1
                for lat_idx, lon_idx in zip(valid_lat_indices, valid_lon_indices)
            ]
        )
        needs_nan_handling = list(np.where(~is_land)[0])
        print(f"  {np.sum(is_land)}/{n_valid} points are on land")
    else:
        # No landmask available - create one from a single time slice
        # Only do this for small point sets since it requires compute()
        print("  No landmask found, checking single time slice for validity...")
        if isinstance(data, xr.Dataset):
            first_var = list(data.data_vars)[0]
            check_data = data[first_var]
        else:
            check_data = data

        # Get a single time slice to create validity mask
        if "time" in check_data.dims:
            single_slice = check_data.isel(time=0)
        else:
            single_slice = check_data

        # Reduce any remaining non-spatial dims
        for dim in list(single_slice.dims):
            if dim not in [lat_dim, lon_dim]:
                single_slice = single_slice.isel({dim: 0})

        # Compute the validity mask (2D)
        if hasattr(single_slice.data, "compute"):
            single_slice = single_slice.compute()

        validity_mask = ~np.isnan(single_slice.values)

        # Check which target points are valid
        is_valid = np.array(
            [
                validity_mask[lat_idx, lon_idx]
                for lat_idx, lon_idx in zip(valid_lat_indices, valid_lon_indices)
            ]
        )
        needs_nan_handling = list(np.where(~is_valid)[0])
        print(f"  {np.sum(is_valid)}/{n_valid} points have valid data")

    # For invalid points, find valid neighbor indices using the 2D mask
    # This is done BEFORE extracting the full timeseries data
    final_lat_indices = valid_lat_indices.copy()
    final_lon_indices = valid_lon_indices.copy()

    if len(needs_nan_handling) > 0:
        print(f"  {len(needs_nan_handling)} point(s) need neighbor search...")

        # Get the validity mask for neighbor searching
        if "landmask" in data.coords or "landmask" in getattr(data, "data_vars", {}):
            mask_2d = landmask_values == 1
        else:
            mask_2d = validity_mask

        lat_size = data.sizes[lat_dim]
        lon_size = data.sizes[lon_dim]
        points_fixed = 0

        for point_idx in needs_nan_handling:
            lat_idx = int(valid_lat_indices[point_idx])
            lon_idx = int(valid_lon_indices[point_idx])

            # Search 3x3 window for valid neighbor using the 2D mask
            found_valid = False
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    ni, nj = lat_idx + di, lon_idx + dj
                    if 0 <= ni < lat_size and 0 <= nj < lon_size:
                        if mask_2d[ni, nj]:
                            # Found a valid neighbor - use its indices
                            final_lat_indices[point_idx] = ni
                            final_lon_indices[point_idx] = nj
                            found_valid = True
                            points_fixed += 1
                            break
                if found_valid:
                    break

        print(
            f"  Fixed {points_fixed}/{len(needs_nan_handling)} points with valid neighbors"
        )

    # Now extract the full timeseries data using the corrected indices
    print("  Extracting gridcells...")

    # Use vectorized advanced indexing with DataArray indexers
    # This is more efficient than stack+isel because it doesn't create a complex MultiIndex
    lat_indexer = xr.DataArray(final_lat_indices, dims=["points"])
    lon_indexer = xr.DataArray(final_lon_indices, dims=["points"])

    # Select all points at once using vectorized indexing
    result = data.isel({lat_dim: lat_indexer, lon_dim: lon_indexer})

    # Add coordinate information for each point
    actual_lats = data[lat_dim].values[final_lat_indices]
    actual_lons = data[lon_dim].values[final_lon_indices]

    # Assign coordinates to the result
    result = result.assign_coords(
        {
            lat_dim: ("points", actual_lats),
            lon_dim: ("points", actual_lons),
        }
    )

    # Print coordinates if requested
    if print_coords:
        for i in range(n_valid):
            print(
                f"  Point {i+1}: ({valid_lats[i]:.4f}, {valid_lons[i]:.4f}) -> "
                f"({actual_lats[i]:.4f}, {actual_lons[i]:.4f})"
            )

    # Reorder dimensions to put 'points' at the end
    if isinstance(result, xr.Dataset):
        first_var = list(result.data_vars)[0]
        all_dims = list(result[first_var].dims)
    else:
        all_dims = list(result.dims)

    if "points" in all_dims:
        all_dims.remove("points")
        all_dims.append("points")
        result = result.transpose(*all_dims)

    print(f"Done! Extracted {n_valid} gridcell(s)")
    return result
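
A usage sketch for multiple points; "data" is assumed to be a gridded Dataset retrieved through climakitae:

>>> from climakitae.util.utils import get_closest_gridcells
>>> lats = [38.58, 34.05, 37.77]
>>> lons = [-121.49, -118.24, -122.42]
>>> points = get_closest_gridcells(data, lats, lons, print_coords=False)
>>> # the result carries a trailing "points" dimension, one entry per valid pair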

julianDay_to_date(julday, year=None, return_type='str', str_format='%b-%d')

Convert julian day of year to a date object or formatted string.

Parameters:

    julday : int, required
        Julian day (day of year)
    year : int, optional
        Year to use. If None, uses current year or a leap year (2024) based on needs. Default is None.
    return_type : str, optional
        Type of return value: "str" for a formatted string (default), "datetime" for a datetime object, or "date" for a date object.
    str_format : str, optional
        String format of output date when return_type is "str". Default is "%b-%d", which outputs format like "Jan-01".

Returns:

    date : str, datetime.datetime, or datetime.date
        Julian day converted to specified format or object

Examples:

>>> julianDay_to_date(1)
'Jan-01'
>>> julianDay_to_date(32, year=2023, return_type="date")
datetime.date(2023, 2, 1)
>>> julianDay_to_date(60, year=2024, str_format="%Y-%m-%d")
'2024-02-29'
Source code in climakitae/util/utils.py
def julianDay_to_date(
    julday: int, year: int = None, return_type: str = "str", str_format: str = "%b-%d"
) -> Union[str, datetime.datetime, datetime.date]:
    """Convert julian day of year to a date object or formatted string.

    Parameters
    ----------
    julday : int
        Julian day (day of year)
    year : int, optional
        Year to use. If None, uses current year or a leap year (2024) based on needs.
        Default is None.
    return_type : str, optional
        Type of return value:
        - "str": formatted string (default)
        - "datetime": datetime object
        - "date": date object
    str_format : str, optional
        String format of output date when return_type is "str".
        Default is "%b-%d" which outputs format like "Jan-01".

    Returns
    -------
    date : str, datetime.datetime, or datetime.date
        Julian day converted to specified format or object

    Examples
    --------
    >>> julianDay_to_date(1)
    'Jan-01'
    >>> julianDay_to_date(32, year=2023, return_type="date")
    datetime.date(2023, 2, 1)
    >>> julianDay_to_date(60, year=2024, str_format="%Y-%m-%d")
    '2024-02-29'

    """
    # Determine which year to use
    if year is None:
        year = datetime.datetime.now().year

    # Create datetime object from julian day
    date_obj = datetime.datetime.strptime(f"{year}.{julday}", "%Y.%j")

    # Return appropriate type
    match return_type:
        case "str":
            return date_obj.strftime(str_format)
        case "datetime":
            return date_obj
        case "date":
            return date_obj.date()
        case _:
            raise ValueError("return_type must be 'str', 'datetime', or 'date'")

readable_bytes(b)

Return the given bytes as a human friendly KB, MB, GB, or TB string.

Parameters:

    b : int, required
        Size in bytes.

Returns:

    str

Code from stackoverflow: https://stackoverflow.com/questions/12523586/python-format-size-application-converting-b-to-kb-mb-gb-tb

Source code in climakitae/util/utils.py
def readable_bytes(b: int) -> str:
    """Return the given bytes as a human friendly KB, MB, GB, or TB string.

    Parameters
    ----------
    b : int
        Size in bytes.

    Returns
    -------
    str

    Code from stackoverflow: https://stackoverflow.com/questions/12523586/python-format-size-application-converting-b-to-kb-mb-gb-tb

    """
    b = float(b)
    kb = 1024
    mb = kb**2  # 1,048,576
    gb = kb**3  # 1,073,741,824
    tb = kb**4  # 1,099,511,627,776

    match b:
        case _ if b < kb:
            return f"{b} bytes"
        case _ if kb <= b < mb:
            return f"{b / kb:.2f} KB"
        case _ if mb <= b < gb:
            return f"{b / mb:.2f} MB"
        case _ if gb <= b < tb:
            return f"{b / gb:.2f} GB"
        case _ if tb <= b:
            return f"{b / tb:.2f} TB"

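A minimal usage sketch:

>>> from climakitae.util.utils import readable_bytes
>>> readable_bytes(512)
'512.0 bytes'
>>> readable_bytes(2048)
'2.00 KB'
>>> readable_bytes(5 * 1024**3)
'5.00 GB'
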
reproject_data(xr_da, proj='EPSG:4326', fill_value=np.nan)

Reproject xr.DataArray using rioxarray.

Parameters:

    xr_da : xr.DataArray, required
        2-or-3-dimensional DataArray, with 2 spatial dimensions
    proj : str, optional
        proj to use for reprojection (default "EPSG:4326" -- lat/lon coords)
    fill_value : float, optional
        fill value (default np.nan)

Returns:

    data_reprojected : xr.DataArray
        2-or-3-dimensional reprojected DataArray

Raises:

    ValueError
        if input data does not have spatial coords x,y
    ValueError
        if input data has more than 5 dimensions

Source code in climakitae/util/utils.py
def reproject_data(
    xr_da: xr.DataArray, proj: str = "EPSG:4326", fill_value: float = np.nan
) -> xr.DataArray:
    """Reproject xr.DataArray using rioxarray.

    Parameters
    ----------
    xr_da : xr.DataArray
        2-or-3-dimensional DataArray, with 2 spatial dimensions
    proj : str
        proj to use for reprojection (default to "EPSG:4326"-- lat/lon coords)
    fill_value : float
        fill value (default to np.nan)

    Returns
    -------
    data_reprojected : xr.DataArray
        2-or-3-dimensional reprojected DataArray

    Raises
    ------
    ValueError
        if input data does not have spatial coords x,y
    ValueError
        if input data has more than 5 dimensions

    """

    def _reproject_data_4D(
        data: xr.DataArray,
        reproject_dim: str,
        proj: str = "EPSG:4326",
        fill_value: float = np.nan,
    ) -> xr.DataArray:
        """Reproject 4D xr.DataArray across an input dimension

        Parameters
        ----------
        data : xr.DataArray
            4-dimensional DataArray, with 2 spatial dimensions
        reproject_dim : str
            name of dimensions to use
        proj : str
            proj to use for reprojection (default to "EPSG:4326"-- lat/lon coords)
        fill_value : float
            fill value (default to np.nan)

        Returns
        -------
        data_reprojected : xr.DataArray
            4-dimensional reprojected DataArray

        """
        rp_list = []
        for i in range(len(data[reproject_dim])):
            dp_i = data[i].rio.reproject(
                proj, nodata=fill_value
            )  # Reproject each index in that dimension
            rp_list.append(dp_i)
        data_reprojected = xr.concat(
            rp_list, dim=reproject_dim
        )  # Concat along reprojection dim to get entire dataset reprojected
        return data_reprojected

    def _reproject_data_5D(
        data: xr.DataArray,
        reproject_dim: list[str],
        proj: str = "EPSG:4326",
        fill_value: float = np.nan,
    ) -> xr.DataArray:
        """Reproject 5D xr.DataArray across two input dimensions

        Parameters
        ----------
        data : xr.DataArray
            5-dimensional DataArray, with 2 spatial dimensions
        reproject_dim : list
            list of str dimension names to use
        proj : str
            proj to use for reprojection (default to "EPSG:4326"-- lat/lon coords)
        fill_value : float
            fill value (default to np.nan)

        Returns
        -------
        data_reprojected : xr.DataArray
            5-dimensional reprojected DataArray

        """
        rp_list_j = []
        reproject_dim_j = reproject_dim[0]
        for j in range(len(data[reproject_dim_j])):
            rp_list_i = []
            reproject_dim_i = reproject_dim[1]
            for i in range(len(data[reproject_dim_i])):
                dp_i = data[j, i].rio.reproject(
                    proj, nodata=fill_value
                )  # Reproject each index in that dimension
                rp_list_i.append(dp_i)
            data_reprojected_i = xr.concat(
                rp_list_i, dim=reproject_dim_i
            )  # Concat along reprojection dim to get entire dataset reprojected
            rp_list_j.append(data_reprojected_i)
        data_reprojected = xr.concat(rp_list_j, dim=reproject_dim_j)
        return data_reprojected

    # Raise error if data doesn't have spatial dimensions x,y
    if not set(["x", "y"]).issubset(xr_da.dims):
        raise ValueError(
            (
                "Input DataArray cannot be reprojected because it"
                " does not contain spatial dimensions x,y"
            )
        )

    # Drop non-dimension coords. Will cause error with rioxarray
    coords = [coord for coord in xr_da.coords if coord not in xr_da.dims]
    data = xr_da.drop_vars(coords)

    # Re-write crs to data using original dataset
    data.rio.write_crs(xr_da.rio.crs, inplace=True)

    # Get non-spatial dimensions
    non_spatial_dims = [dim for dim in data.dims if dim not in ["x", "y"]]

    # test for different dims
    numofdims = len(data.dims)
    # 2 or 3D DataArray
    match numofdims:
        case numofdims if numofdims <= 3:
            data_reprojected = data.rio.reproject(proj, nodata=fill_value)
        # 4D DataArray
        case 4:
            data_reprojected = _reproject_data_4D(
                data=data,
                reproject_dim=non_spatial_dims[0],
                proj=proj,
                fill_value=fill_value,
            )
        # 5D DataArray
        case 5:
            data_reprojected = _reproject_data_5D(
                data=data,
                reproject_dim=non_spatial_dims[:-1],
                proj=proj,
                fill_value=fill_value,
            )
        case _:
            raise ValueError(
                "DataArrays with dimensions greater than 5 are not currently supported"
            )

    # Reassign attribute to reflect reprojection
    data_reprojected.attrs["grid_mapping"] = proj
    return data_reprojected
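
A usage sketch; "da" is assumed to be a CRS-aware DataArray with x/y spatial dimensions (e.g. WRF data on a Lambert Conformal grid):

>>> from climakitae.util.utils import reproject_data
>>> da_latlon = reproject_data(da, proj="EPSG:4326")  # reproject to lat/lon coordinates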

compute_annual_aggreggate(data, name, num_grid_cells)

Calculates the annual sum of HDD and CDD

Parameters:

    data : xr.DataArray, required
    name : str, required
    num_grid_cells : int, required

Returns:

    annual_ag : xr.DataArray

Source code in climakitae/util/utils.py
def compute_annual_aggreggate(
    data: xr.DataArray, name: str, num_grid_cells: int
) -> xr.DataArray:
    """Calculates the annual sum of HDD and CDD

    Parameters
    ----------
    data : xr.DataArray
    name : str
    num_grid_cells : int

    Returns
    -------
    annual_ag : xr.DataArray

    """
    annual_ag = data.squeeze().groupby("time.year").sum(["time"])  # Aggregate annually
    annual_ag = annual_ag / num_grid_cells  # Divide by number of gridcells
    annual_ag.name = name  # Give new name to dataset
    return annual_ag

compute_multimodel_stats(data)

Calculates model mean, min, max, median across simulations

Used in heat_index.ipynb and degree_days.ipynb

Parameters:

    data : xr.DataArray, required

Returns:

    stats_concat : xr.DataArray

Source code in climakitae/util/utils.py
def compute_multimodel_stats(data: xr.DataArray) -> xr.DataArray:
    """Calculates model mean, min, max, median across simulations

    Used in heat_index.ipynb and degree_days.ipynb

    Parameters
    ----------
    data : xr.DataArray

    Returns
    -------
    stats_concat : xr.DataArray

    """
    # Can hard-code as "sim" once DFU/degree_days.ipynb updated for new core.
    # But keeping "simulation" as an option until then.
    if "simulation" in data.dims:
        sim_dim = "simulation"
    else:
        sim_dim = "sim"

    # Compute mean across simulation dimensions and add is as a coordinate
    sim_mean = (
        data.mean(dim=sim_dim)
        .assign_coords({sim_dim: "simulation mean"})
        .expand_dims(sim_dim)
    )

    # Compute multimodel min
    sim_min = (
        data.min(dim=sim_dim)
        .assign_coords({sim_dim: "simulation min"})
        .expand_dims(sim_dim)
    )

    # Compute multimodel max
    sim_max = (
        data.max(dim=sim_dim)
        .assign_coords({sim_dim: "simulation max"})
        .expand_dims(sim_dim)
    )

    # Compute multimodel median
    sim_median = (
        data.median(dim=sim_dim)
        .assign_coords({sim_dim: "simulation median"})
        .expand_dims(sim_dim)
    )

    # Add to main dataset
    stats_concat = xr.concat(
        [data, sim_mean, sim_min, sim_max, sim_median], dim=sim_dim
    )
    stats_concat.attrs["name"] = data.name
    return stats_concat
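
A usage sketch; "annual_data" is assumed to be a DataArray with a "sim" (or "simulation") dimension:

>>> from climakitae.util.utils import compute_multimodel_stats
>>> stats = compute_multimodel_stats(annual_data)
>>> # adds "simulation mean", "simulation min", "simulation max", and
>>> # "simulation median" entries along the simulation dimension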

trendline(data, kind='mean')

Calculates the trendline of the multi-model mean or median.

Parameters:

    data : xr.DataArray, required
    kind : str, optional
        Options are 'mean' and 'median'. Default is 'mean'.

Returns:

    trendline : xr.DataArray

Note:

    Development note: If an additional 'kind' option is required for trendline, compute_multimodel_stats must be modified to support it.

Source code in climakitae/util/utils.py
def trendline(data: xr.DataArray, kind: str = "mean") -> xr.DataArray:
    """Calculates treadline of the multi-model mean or median.

    Parameters
    ----------
    data : xr.Dataarray
    kind : str , optional
        Options are 'mean' and 'median'

    Returns
    -------
    trendline : xr.DataArray

    Note
    ----
    1. Development note: If an additional option to trendline 'kind' is required,
    compute_multimodel_stats must be modified to update optionality.

    """
    # Can hard-code as "sim" once DFU/degree_days.ipynb updated for new core.
    # But keeping "simulation" as an option until then.
    if "simulation" in data.dims:
        sim_dim = "simulation"
    else:
        sim_dim = "sim"

    ret_trendline = xr.Dataset()
    match kind:
        case "mean":
            if "simulation mean" not in data[sim_dim]:
                raise ValueError(
                    "Invalid data provided, please pass the multimodel stats from compute_multimodel_stats"
                )

            data_sim_mean = data.sel({sim_dim: "simulation mean"})
            m, b = data_sim_mean.polyfit(dim="year", deg=1).polyfit_coefficients.values
            ret_trendline = m * data_sim_mean.year + b  # y = mx + b

        case "median":
            if "simulation median" not in data[sim_dim]:
                raise ValueError(
                    "Invalid data provided, please pass the multimodel stats from compute_multimodel_stats"
                )

            data_sim_med = data.sel({sim_dim: "simulation median"})
            m, b = data_sim_med.polyfit(dim="year", deg=1).polyfit_coefficients.values
            ret_trendline = m * data_sim_med.year + b  # y = mx + b
        case _:
            raise ValueError(
                "Invalid kind provided, please pass either 'mean' or 'median' as the kind"
            )
    ret_trendline.name = "trendline"
    return ret_trendline
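
A usage sketch; "annual_data" is assumed to have "sim" and "year" dimensions, and compute_multimodel_stats must be run first:

>>> from climakitae.util.utils import compute_multimodel_stats, trendline
>>> stats = compute_multimodel_stats(annual_data)
>>> trend = trendline(stats, kind="median")  # linear fit of the multi-model median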

combine_hdd_cdd(data)

Drops specific unneeded coords from HDD/CDD data, independent of station or gridded data source

Parameters:

    data : xr.DataArray, required

Returns:

    data : xr.DataArray

Source code in climakitae/util/utils.py
def combine_hdd_cdd(data: xr.DataArray) -> xr.DataArray:
    """Drops specific unneeded coords from HDD/CDD data, independent of station or gridded data source

    Parameters
    ----------
    data : xr.DataArray

    Returns
    -------
    data : xr.DataArray

    """
    if data.name not in [
        "Annual Heating Degree Days (HDD)",
        "Annual Cooling Degree Days (CDD)",
        "Heating Degree Hours",
        "Cooling Degree Hours",
    ]:
        raise ValueError(
            "Invalid data provided, please pass cooling/heating degree data"
        )

    to_drop = ["scenario", "Lambert_Conformal", "variable"]
    for coord in to_drop:
        if coord in data.coords:
            data = data.drop_vars(coord)

    return data

summary_table(data)

Helper function to organize a dataset object into a pandas DataFrame for easier analysis.

Parameters:

    data : xr.Dataset, required

Returns:

    df : pd.DataFrame
        df is organized so that the simulations are stacked in individual columns by year/time

Source code in climakitae/util/utils.py
def summary_table(data: xr.Dataset) -> pd.DataFrame:
    """Helper function to organize dataset object into a pandas dataframe for ease.

    Parameters
    ----------
    data : xr.Dataset

    Returns
    -------
    df : pd.DataFrame
        df is organized so that the simulations are stacked in individual columns by year/time

    """

    # Identify whether the temporal dimension is "time" or "year"
    if "time" in data.dims:
        df = data.drop_vars(
            ["lakemask", "landmask", "lat", "lon", "Lambert_Conformal", "x", "y"]
        ).to_dataframe(dim_order=["time", "sim"])

        df = df.unstack()
        df = df.sort_values(by=["time"])

    elif "year" in data.dims:
        df = data.drop_vars(
            ["lakemask", "landmask", "lat", "lon", "Lambert_Conformal", "x", "y"]
        ).to_dataframe(dim_order=["year", "sim"])

        df = df.unstack()
        df = df.sort_values(by=["year"])

    return df

convert_to_local_time(data, lon=UNSET, lat=UNSET)

Convert time dimension from UTC to local time for the grid or station.

Parameters:

    data : xr.DataArray or xr.Dataset, required
        Input data.
    lon : float, optional
        Mean longitude of dataset if no lat/lon coordinates (default UNSET)
    lat : float, optional
        Mean latitude of dataset if no lat/lon coordinates (default UNSET)

Returns:

    xr.DataArray or xr.Dataset
        Data with converted time coordinate.

Source code in climakitae/util/utils.py
def convert_to_local_time(
    data: xr.DataArray | xr.Dataset,
    lon: float = UNSET,
    lat: float = UNSET,
) -> xr.DataArray | xr.Dataset:
    """Convert time dimension from UTC to local time for the grid or station.

    Parameters
    ----------
    data : xr.DataArray or xr.Dataset
        Input data.
    lon : float
        Mean longitude of dataset if no lat/lon coordinates
    lat : float
        Mean latitude of dataset if no lat/lon coordinates

    Returns
    -------
    xr.DataArray or xr.Dataset
        Data with converted time coordinate.

    """

    # Only converting hourly data
    if not (frequency := data.attrs.get("frequency", None)):
        # Make a guess at frequency
        timestep = pd.Timedelta(
            data.time[1].item() - data.time[0].item()
        ).total_seconds()
        match timestep:
            case 3600:
                frequency = "hourly"
            case 86400:
                frequency = "daily"
            case _ if timestep > 86400:
                frequency = "monthly"

    # If timescale is not hourly, no need to convert
    if frequency != "hourly":
        print(
            "This dataset's timescale is not granular enough to covert to local time. Local timezone conversion requires hourly data."
        )
        return data

    # Find out if Stations or Gridded type
    if not (data_type := data.attrs.get("data_type", None)):
        if isinstance(data, xr.core.dataarray.DataArray):
            print(
                "Data Array attribute 'data_type' not found. Please set 'data_type' to 'Stations' or 'Gridded'."
            )
            return data
        else:
            try:
                # Grab from one of data arrays in dataset
                variable = list(data.keys())[0]
                data_type = data[variable].attrs["data_type"]
            except KeyError:
                print(
                    f"Could not find attribute 'data_type' attribute set in {variable} attributes. Please set data_type attribute."
                )
                return data

    # Get latitude/longitude information
    match data_type:
        case "Stations":
            # Read stations database
            stations_df = pd.read_csv(HADISD_STATIONS_URL)
            stations_df = stations_df.drop(columns=["Unnamed: 0"])

            # Filter by selected station(s) - assume first station if multiple
            match data:
                case xr.DataArray():
                    station_name = data.name
                case xr.Dataset():
                    # Grab first one
                    station_name = list(data.keys())[0]
                case _:
                    print(
                        f"Invalid data type {type(data)}. Please provide xarray DataArray or Dataset."
                    )
                    return data
            station_data = stations_df[stations_df["station"] == station_name]
            if len(station_data) == 0:
                print(
                    f"Station {data.name} not found in Stations CSV. Please set Data Array name to valid station name."
                )
                return data
            lat = station_data["LAT_Y"].values[0]
            lon = station_data["LON_X"].values[0]

        case "Gridded":
            # if both lat and lon are set, can move on to timezone finding.
            if (lat is UNSET) or (lon is UNSET):
                try:
                    # Finding central lat/lon coordinates
                    lat_idx = len(data.lat) // 2
                    lon_idx = len(data.lon) // 2
                    lat = data.lat.isel(lat=lat_idx).item()
                    lon = data.lon.isel(lon=lon_idx).item()
                except AttributeError:
                    print(
                        "lat/lon coordinates not found in data. Please pass in data with 'lon' and 'lat' coordinates or set both 'lon' and 'lat' arguments."
                    )
                    return data

        case _:
            print(
                "Invalid data type attribute. Data type should be 'Stations' or 'Gridded'."
            )
            return data

    # Find timezone for the coordinates
    tf = TimezoneFinder()
    local_tz = tf.timezone_at(lng=lon, lat=lat)

    # Change datetime objects to local time
    new_time = (
        pd.DatetimeIndex(data.time)
        .tz_localize("UTC")
        .tz_convert(local_tz)
        .tz_localize(None)
        .astype("datetime64[ns]")
    )
    data["time"] = new_time

    print(f"Data converted to {local_tz} timezone.")

    # Add timezone attribute to data
    match data:
        case xr.DataArray():
            data = data.assign_attrs({"timezone": local_tz})
        case xr.Dataset():
            variables = list(data.keys())
            for variable in variables:
                data[variable] = data[variable].assign_attrs({"timezone": local_tz})
        case _:
            print(f"Invalid data type {type(data)}. Could not set timezone attribute.")

    return data
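
A usage sketch; "hourly_data" is assumed to be hourly gridded data (data_type attribute "Gridded") with lat/lon coordinates; pass lat/lon explicitly if the data lacks them:

>>> from climakitae.util.utils import convert_to_local_time
>>> local = convert_to_local_time(hourly_data, lon=-121.5, lat=38.5)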

add_dummy_time_to_wl(wl_da, freq_name='daily')

Replace the [hours/days/months]_from_center or time_delta dimension in a DataArray returned from WarmingLevels with a dummy time index for calculations with tools that require a time dimension.

Parameters:

    wl_da : xr.DataArray, required
        The input Warming Levels DataArray. It is expected to have a time-based dimension, typically one that includes "from_center" in its name or "time_delta", indicating the time dimension relative to the year at which the given warming level is reached per simulation.
    freq_name : str, optional
        The frequency name to use when time_delta is the time dimension. Options are "hourly", "daily", or "monthly". Default is "daily".

Returns:

    xr.DataArray
        A modified version of the input DataArray with the original time dimension replaced by a dummy time series. The new dimension will be named "time".

Notes
  • The function looks for the dimension name containing "from_center" to identify the time-based dimension.
  • It supports creating dummy time series with frequencies of hours, days, or months, based on the prefix of the dimension name.
  • The dummy time series starts from "2000-01-01".
Source code in climakitae/util/utils.py
def add_dummy_time_to_wl(wl_da: xr.DataArray, freq_name="daily") -> xr.DataArray:
    """Replace the [hours/days/months]_from_center or time_delta dimension in a DataArray returned from WarmingLevels with a dummy time index for calculations with tools that require a time dimension.

    Parameters
    ----------
    wl_da : xr.DataArray
        The input Warming Levels DataArray. It is expected to have a time-based dimension which typically includes "from_center"
        in its name or time_delta indicating the time dimension in relation to the year that the given warming level is reached per simulation.
    freq_name : str, optional
        The frequency name to use when time_delta is the time dimension. Options are "hourly", "daily", or "monthly". Default is "daily".

    Returns
    -------
    xr.DataArray
        A modified version of the input DataArray with the original time dimension replaced by a dummy time series. The new dimension
        will be named "time".

    Notes
    -----
    - The function looks for the dimension name containing "from_center" to identify the time-based dimension.
    - It supports creating dummy time series with frequencies of hours, days, or months, based on the prefix of the dimension name.
    - The dummy time series starts from "2000-01-01".

    """
    ### Adjusting the time index into a dummy time-series for counting

    # Finding time-based dimension
    wl_time_dim = ""

    for dim in wl_da.dims:
        if dim == "time_delta":
            wl_time_dim = "time_delta"
        elif "from_center" in dim:
            wl_time_dim = dim

    if wl_time_dim == "":
        raise ValueError(
            "DataArray does not contain necessary warming level information."
        )

    # Determine time frequency name and pandas freq string mapping
    if wl_time_dim == "time_delta":

        try:
            time_freq_name = wl_da.frequency
        except AttributeError:
            time_freq_name = freq_name

        name_to_freq = {
            "1hr": "h",
            "hourly": "h",
            "daily": "D",
            "day": "D",
            "mon": "MS",
            "monthly": "MS",
        }

    else:
        time_freq_name = wl_time_dim.split("_")[0]
        name_to_freq = {"hours": "h", "days": "D", "months": "MS"}

    freq = name_to_freq[time_freq_name]

    # Number of time units per normal year
    num_time_units_per_year = {"h": 8760, "D": 365, "MS": 12}

    # Calculate total number of units in wl_da along wl_time_dim
    len_time = len(wl_da[wl_time_dim])

    # Calculate approximate number of years spanned by data
    years_span = len_time / num_time_units_per_year[freq]
    start_year = 2000
    end_year = int(start_year + years_span - 1)

    # Calculate total leap days in the period
    total_leap_days = sum(
        calendar.isleap(year) for year in range(start_year, end_year + 1)
    )

    # Adjust number of periods to add leap day hours if hourly, else add leap days as periods
    extra_periods = total_leap_days * 24 if freq == "h" else total_leap_days

    # Edge cases:
    # if total time passed in is less than 60 days (when Feb 29th is), then don't add extra_periods
    # if we're looking at monthly data, then don't add extra_periods
    if (
        (freq == "h" and len_time < 24 * 60)
        or (freq == "D" and len_time < 60)
        or (freq == "MS")
    ):
        extra_periods = 0

    # Create the dummy timestamps including leap day adjustments
    timestamps = pd.date_range(
        start="2000-01-01", periods=len_time + extra_periods, freq=freq
    )

    # Filter out leap days (Feb 29)
    timestamps = timestamps[~((timestamps.month == 2) & (timestamps.day == 29))]

    # Replacing WL timestamps with dummy timestamps so that calculations from tools like thresholds_tools
    # can be computed on a DataArray with a time dimension
    wl_da = wl_da.assign_coords({wl_time_dim: timestamps}).rename({wl_time_dim: "time"})
    return wl_da
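
A usage sketch; "wl_da" is assumed to come from a WarmingLevels retrieval and to carry a "time_delta" or "*_from_center" dimension:

>>> from climakitae.util.utils import add_dummy_time_to_wl
>>> wl_with_time = add_dummy_time_to_wl(wl_da, freq_name="daily")
>>> # the warming-level axis is replaced by a dummy "time" axis starting 2000-01-01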

downscaling_method_to_activity_id(downscaling_method, reverse=False)

Convert downscaling method to activity id to match catalog names

Parameters:

    downscaling_method : str, required
        Downscaling method
    reverse : bool, optional
        Set reverse=True to get downscaling method from input activity_id. Default is False.

Returns:

    str

Source code in climakitae/util/utils.py
def downscaling_method_to_activity_id(
    downscaling_method: str, reverse: bool = False
) -> str:
    """Convert downscaling method to activity id to match catalog names

    Parameters
    ----------
    downscaling_method : str
        Downscaling method
    reverse : boolean, optional
        Set reverse=True to get downscaling method from input activity_id
        Default to False

    Returns
    -------
    str

    """
    downscaling_dict = {"Dynamical": "WRF", "Statistical": "LOCA2"}

    if reverse:
        downscaling_dict = {v: k for k, v in downscaling_dict.items()}
    return downscaling_dict[downscaling_method]

resolution_to_gridlabel(resolution, reverse=False)

Convert resolution format to grid_label format matching catalog names.

Parameters:

    resolution : str, required
        resolution
    reverse : bool, optional
        Set reverse=True to get resolution format from input grid_label. Default is False.

Returns:

    str

Source code in climakitae/util/utils.py
def resolution_to_gridlabel(resolution: str, reverse: bool = False) -> str:
    """Convert resolution format to grid_label format matching catalog names.

    Parameters
    ----------
    resolution : str
        resolution
    reverse : boolean, optional
        Set reverse=True to get resolution format from input grid_label.
        Default to False

    Returns
    -------
    str

    """
    res_dict = {"45 km": "d01", "9 km": "d02", "3 km": "d03"}

    if reverse:
        res_dict = {v: k for k, v in res_dict.items()}
    return res_dict[resolution]

timescale_to_table_id(timescale, reverse=False)

Convert timescale format to table_id format matching catalog names.

Parameters:

    timescale : str, required
    reverse : bool, optional
        Set reverse=True to get timescale format from input table_id. Default is False.

Returns:

    str

Source code in climakitae/util/utils.py
def timescale_to_table_id(timescale: str, reverse: bool = False) -> str:
    """Convert resolution format to table_id format matching catalog names.

    Parameters
    ----------
    timescale : str
    reverse : boolean, optional
        Set reverse=True to get resolution format from input table_id.
        Default to False

    Returns
    -------
    str

    """
    # yearly max is not an option in the Selections GUI, but it's included here to make parsing through the data easier for the non-GUI data access/view options
    timescale_dict = {
        "monthly": "mon",
        "daily": "day",
        "hourly": "1hr",
        "yearly_max": "yrmax",
    }

    if reverse:
        timescale_dict = {v: k for k, v in timescale_dict.items()}
    return timescale_dict[timescale]

scenario_to_experiment_id(scenario, reverse=False)

Convert scenario format to experiment_id format matching catalog names.

Parameters:

    scenario : str, required
    reverse : bool, optional
        Set reverse=True to get scenario format from input experiment_id. Default is False.

Returns:

    str

Source code in climakitae/util/utils.py
def scenario_to_experiment_id(scenario: str, reverse: bool = False) -> str:
    """Convert scenario format to experiment_id format matching catalog names.

    Parameters
    ----------
    scenario : str
    reverse : boolean, optional
        Set reverse=True to get scenario format from input experiment_id.
        Default to False

    Returns
    -------
    str

    """
    scenario_dict = {
        "Historical Reconstruction": "reanalysis",
        "Historical Climate": "historical",
        "SSP 2-4.5": "ssp245",
        "SSP 5-8.5": "ssp585",
        "SSP 3-7.0": "ssp370",
    }

    if reverse:
        scenario_dict = {v: k for k, v in scenario_dict.items()}
    return scenario_dict[scenario]
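
A combined usage sketch for the catalog-name mapping helpers above:

>>> from climakitae.util.utils import (
...     downscaling_method_to_activity_id,
...     resolution_to_gridlabel,
...     timescale_to_table_id,
...     scenario_to_experiment_id,
... )
>>> downscaling_method_to_activity_id("Dynamical")
'WRF'
>>> resolution_to_gridlabel("9 km")
'd02'
>>> timescale_to_table_id("monthly")
'mon'
>>> scenario_to_experiment_id("ssp370", reverse=True)
'SSP 3-7.0'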

clip_to_shapefile(data, shapefile_path, feature=(), name='user-defined', **kwargs)

Use a shapefile to select an area subset of AE data.

By default, this function will clip the data to the area covered by all features in the shapefile. To clip to specific features, use the feature keyword.

Parameters:

    data : xr.Dataset | xr.DataArray, required
        Data to be clipped.
    shapefile_path : str, required
        Filepath to shapefile. Shapefile must include valid CRS.
    feature : tuple(str, str | int | float | list), optional
        Tuple containing attribute name and value(s) for target feature(s). Default is ().
    name : str, optional
        Location name to record in data attributes if 'feature' parameter is not set. Default is "user-defined".
    **kwargs
        Additional arguments to pass to the rioxarray clip function

Returns:

    clipped : xr.Dataset | xr.DataArray
        Returns same type as 'data', but grid is clipped to shapefile feature(s).

Source code in climakitae/util/utils.py
def clip_to_shapefile(
    data: xr.Dataset | xr.DataArray,
    shapefile_path: str,
    feature: tuple[str, Any] = (),
    name: str = "user-defined",
    **kwargs,
) -> xr.Dataset | xr.DataArray:
    """Use a shapefile to select an area subset of AE data.

    By default, this function will clip the data to the area covered by all features in
    the shapefile. To clip to specific features, use the feature keyword.

    Parameters
    ----------
    data : xr.Dataset | xr.DataArray
        Data to be clipped.
    shapefile_path : str
        Filepath to shapefile. Shapefile must include valid CRS.
    feature : tuple(str, str | int | float | list)
        Tuple containing attribute name and value(s) for target feature(s) (optional).
    name : str
        Location name to record in data attributes if 'feature' parameter is not set (optional).
    **kwargs
        Additional arguments to pass to the rioxarray clip function

    Returns
    -------
    clipped: xr.Dataset | xr.DataArray
        Returns same type as 'data', but grid is clipped to shapefile feature(s).

    """
    if data.rio.crs is None:
        raise RuntimeError(
            "No CRS found on input parameter 'data'. Use rioxarray write_crs() method to set CRS."
        )

    region = gpd.read_file(shapefile_path)

    if region.crs is None:
        raise RuntimeError(
            "No CRS found on data read from shapefile. Verify that shapefile contains valid CRS information."
        )

    # Select only user provided feature
    if feature:
        try:
            print("Selecting feature", feature)
            if isinstance(feature[1], list):
                region = region[region[feature[0]].isin(feature[1])]
            else:
                region = region[region[feature[0]] == feature[1]]
            if len(region) == 0:  # No features found
                raise ValueError("None of the requested features were found.")
        except ValueError as err:
            raise err
        except Exception as err:
            raise RuntimeError(
                "Could not select one or more feature(s) {0} in {1} ".format(
                    feature, shapefile_path
                )
            ) from err

    try:
        clipped = data.rio.clip(
            region.geometry.apply(mapping), region.crs, drop=True, **kwargs
        )
    except rio.exceptions.NoDataInBounds as err:
        msg = "Can't clip feature. Your grid resolution may be too low for your shapefile feature, or your shapefile's CRS may be incorrectly set."
        raise RuntimeError(msg) from err
    except Exception as err:
        raise err

    if feature:
        if isinstance(feature[1], list):
            location = [str(item) for item in feature[1]]
        else:
            location = [str(feature[1])]
    else:
        location = [name]
    clipped.attrs["location_subset"] = location

    return clipped
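
For example, a minimal usage sketch. The shapefile path and the "NAME" attribute below are hypothetical, and data is assumed to be a previously retrieved AE Dataset or DataArray that already has a CRS written (e.g., with rioxarray's write_crs):

>>> from climakitae.util.utils import clip_to_shapefile
>>> clipped = clip_to_shapefile(data, "counties.shp", feature=("NAME", "Alameda"))            # one feature
>>> clipped = clip_to_shapefile(data, "counties.shp", feature=("NAME", ["Alameda", "Fresno"]))  # several features
>>> clipped = clip_to_shapefile(data, "counties.shp", name="my region")                        # whole shapefile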

clip_gpd_to_shapefile(gdf, shapefile)

Use a shapefile to select an area subset of a geodataframe. Used to subset stationlist to shapefile area.

Parameters:

Name Type Description Default
gdf GeoDataFrame

Data to be clipped.

required
shapefile GeoDataFrame

Shapefile must include valid CRS.

required

Returns:

Name Type Description
clipped GeoDataFrame

Subsetted geodataframe within shapefile area of interest.

Source code in climakitae/util/utils.py
def clip_gpd_to_shapefile(
    gdf: gpd.GeoDataFrame,
    shapefile: gpd.GeoDataFrame,
) -> gpd.GeoDataFrame:
    """Use a shapefile to select an area subset of a geodataframe.
    Used to subset stationlist to shapefile area.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame
        Data to be clipped.
    shapefile : gpd.GeoDataFrame
        Shapefile must include valid CRS.

    Returns
    -------
    clipped : gpd.GeoDataFrame
        Subsetted geodataframe within shapefile area of interest.
    """

    # Adds coordinates
    geom = gpd.points_from_xy(gdf["longitude"], gdf["latitude"])
    sub_gdf = gpd.GeoDataFrame(gdf, geometry=geom).set_crs(
        crs="EPSG:3857", allow_override=True
    )

    # Check CRS
    if sub_gdf.crs is None or shapefile.crs is None:
        raise RuntimeError(
            "Both input GeoDataFrame and shapefile must have a defined CRS."
        )

    if sub_gdf.crs != shapefile.crs:
        shapefile = shapefile.to_crs(sub_gdf.crs)

    # Subset for stations within area boundaries
    clipped = sub_gdf[sub_gdf.geometry.intersects(shapefile.union_all())]

    if clipped.empty:
        raise RuntimeError(
            "Clipping returned an empty GeoDataFrame; check geometries and CRS."
        )

    # Useful information
    print(f"Number of stations within area: {len(clipped)}")

    return clipped
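
A minimal sketch. The station table and shapefile path here are hypothetical; the input GeoDataFrame only needs "latitude" and "longitude" columns, and the shapefile must carry a valid CRS:

>>> import geopandas as gpd
>>> import pandas as pd
>>> from climakitae.util.utils import clip_gpd_to_shapefile
>>> stations = gpd.GeoDataFrame(
...     pd.DataFrame({"station": ["A", "B"], "latitude": [37.8, 36.7], "longitude": [-122.3, -119.8]})
... )
>>> area = gpd.read_file("area_of_interest.shp")  # hypothetical shapefile
>>> subset = clip_gpd_to_shapefile(stations, area)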

Warming Level Utilities

Helper functions related to applying a warming levels approach to a data object

calculate_warming_level(warming_data, gwl_times, level, months, window)

Perform warming level computation for a single warming level.

Assumes the data has already been stacked by simulation and scenario to create a MultiIndex dimension "all_sims" and that the invalid simulations have been removed such that the gwl_times table can be adequately parsed.

Internal function only; see the function _apply_warming_levels_approach for more documentation on how this function is applied internally. Appropriate attributes for new dimensions are applied by the retrieval function (not here).

Parameters:

Name Type Description Default
warming_data DataArray

Data object returned by _get_data_one_var, stacked by simulation/scenario, and then with invalid simulations removed.

required
gwl_times DataFrame

Global warming levels table indicating when each unique model/run/scenario (simulation) reaches each warming level.

required
level float

Warming level. Must be a valid column in gwl_times table.

required
months list of int

Months of the year (as integers) to include in the computation. For example, for a full year: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

required
window int

Years around Global Warming Level (+/-). For example, 15 means a 30-year window.

required

Returns:

Type Description
DataArray

The warming level subset data.

Source code in climakitae/util/warming_levels.py
def calculate_warming_level(
    warming_data: xr.DataArray,
    gwl_times: pd.DataFrame,
    level: float,
    months: list[int],
    window: int,
) -> xr.DataArray:
    """Perform warming level computation for a single warming level.

    Assumes the data has already been stacked by simulation and scenario to
    create a MultiIndex dimension "all_sims" and that the invalid simulations
    have been removed such that the gwl_times table can be adequately parsed.

    Internal function only; see the function _apply_warming_levels_approach for
    more documentation on how this function is applied internally. Appropriate
    attributes for new dimensions are applied by the retrieval function
    (not here).

    Parameters
    ----------
    warming_data : xr.DataArray
        Data object returned by _get_data_one_var, stacked by
        simulation/scenario, and then with invalid simulations removed.
    gwl_times : pd.DataFrame
        Global warming levels table indicating when each unique
        model/run/scenario (simulation) reaches each warming level.
    level : float
        Warming level. Must be a valid column in gwl_times table.
    months : list of int
        Months of the year (as integers) to include in the computation.
        For example, for a full year: ``[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]``
    window : int
        Years around Global Warming Level (+/-). For example, 15 means a
        30-year window.

    Returns
    -------
    xr.DataArray
        The warming level subset data.
    """
    # Raise error if proper processing has not been performed on the data before calling the function
    if "all_sims" not in warming_data.dims:
        raise AttributeError(
            "Missing an `all_sims` dimension on the dataset. Create `all_sims` with .stack on `simulation` and `scenario`."
        )

    # Apply _get_sliced_data function by simulation dimension
    warming_data = warming_data.groupby("all_sims").map(
        _get_sliced_data, level=level, gwl_times=gwl_times, months=months, window=window
    )

    warming_data = warming_data.expand_dims({"warming_level": [level]})

    # Check that there exist simulations that reached this warming level before cleaning. Otherwise, don't modify anything.
    if not (warming_data.centered_year.isnull()).all():
        # Removing simulations where this warming level is not crossed. (centered_year)
        warming_data = warming_data.sel(all_sims=~warming_data.centered_year.isnull())

    return warming_data
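
Because this is an internal helper, a usage sketch can only be schematic. Assuming warming_data and gwl_times have already been prepared by the retrieval pipeline (see the docstring's preconditions), the call looks like:

>>> stacked = warming_data.stack(all_sims=["simulation", "scenario"])  # create the required MultiIndex dim
>>> subset = calculate_warming_level(
...     stacked, gwl_times, level=2.0, months=list(range(1, 13)), window=15
... )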

drop_invalid_sims(ds, selections)

As part of the warming levels calculation, the data is stacked by simulation and scenario, creating some empty values for that coordinate. Here, we remove those empty coordinate values.

Parameters:

Name Type Description Default
ds Dataset

The dataset must have a dimension all_sims that results from stacking simulation and scenario.

required
selections DataParameters

Warming level data selections

required

Returns:

Type Description
Dataset

The dataset with only valid simulations retained.

Raises:

Type Description
AttributeError

If the dataset does not have an all_sims dimension.

Source code in climakitae/util/warming_levels.py
def drop_invalid_sims(ds: xr.Dataset, selections) -> xr.Dataset:
    """As part of the warming levels calculation, the data is stacked by simulation and scenario, creating some empty values for that coordinate.
    Here, we remove those empty coordinate values.

    Parameters
    ----------
    ds : xr.Dataset
        The dataset must have a
        dimension `all_sims` that results from stacking `simulation` and
        `scenario`.
    selections : DataParameters
        Warming level data selections

    Returns
    -------
    xr.Dataset
        The dataset with only valid simulations retained.

    Raises
    ------
    AttributeError
        If the dataset does not have an `all_sims` dimension.

    """
    df = _get_cat_subset(selections).df

    # Just trying to see simulations across SSPs, not including historical period
    filter_df = df[df["experiment_id"] != "historical"]

    # Creating a valid simulation list to filter the original dataset from
    valid_sim_list = list(
        zip(
            filter_df["activity_id"]
            + "_"
            + filter_df["source_id"]
            + "_"
            + filter_df["member_id"],
            filter_df["experiment_id"].apply(
                lambda val: f"Historical + {scenario_to_experiment_id(val, reverse=True)}"
            ),
        )
    )
    return ds.sel(all_sims=valid_sim_list)

read_warming_level_csvs()

Reads two CSV files containing global warming level (GWL) data.

Returns:

Type Description
tuple[DataFrame, DataFrame]

df : pd.DataFrame
    Time-indexed DataFrame (time as index, simulations as columns).
other_df : pd.DataFrame
    DataFrame with warming levels per simulation (no datetime index).

Source code in climakitae/util/warming_levels.py
def read_warming_level_csvs() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Reads two CSV files containing global warming level (GWL) data.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        df : pd.DataFrame
            Time-indexed DataFrame (time as index, simulations as columns).
        other_df : pd.DataFrame
            DataFrame with warming levels per simulation (no datetime index).

    """
    df = read_csv_file(GWL_1850_1900_TIMEIDX_FILE, index_col="time", parse_dates=True)
    other_df = read_csv_file(GWL_1850_1900_FILE)
    return df, other_df
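
A minimal usage sketch:

>>> from climakitae.util.warming_levels import read_warming_level_csvs
>>> gwl_time_indexed, gwl_per_sim = read_warming_level_csvs()
>>> gwl_time_indexed.index   # DatetimeIndex, one column per simulation
>>> gwl_per_sim.columns      # includes GCM, run, scenario, and warming level columns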

get_wl_timestamp(series, degree)

Finds the first timestamp when the series crosses the specified warming level.

Parameters:

Name Type Description Default
series Series

A time-indexed warming level series.

required
degree float

Target warming level.

required

Returns:

Type Description
str | float

Timestamp as string if crossed, else np.nan.

Source code in climakitae/util/warming_levels.py
def get_wl_timestamp(series: pd.Series, degree: float) -> Union[str, float]:
    """Finds the first timestamp when the series crosses the specified warming level.

    Parameters
    ----------
    series : pd.Series
        A time-indexed warming level series.
    degree : float
        Target warming level.

    Returns
    -------
    str | float
        Timestamp as string if crossed, else np.nan.

    """
    if any(series >= degree):
        return series[series >= degree].index[0].strftime("%Y-%m-%d %H:%M")
    return np.nan
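
A self-contained sketch with a toy warming level series:

>>> import pandas as pd
>>> from climakitae.util.warming_levels import get_wl_timestamp
>>> series = pd.Series(
...     [0.8, 1.1, 1.4, 1.6, 1.9],
...     index=pd.date_range("2000-01-01", periods=5, freq="YS"),
... )
>>> get_wl_timestamp(series, 1.5)
'2003-01-01 00:00'
>>> get_wl_timestamp(series, 3.0)
nan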

create_new_warming_level_table(warming_level)

Returns a table of timestamps when each simulation reaches the given warming level.

Parameters:

Name Type Description Default
warming_level float

New WL to retrieve WL timing for.

required

Returns:

Type Description
pd.DataFrame

Same DataFrame as data/gwl_1850-1900ref.csv, with an added column for the passed warming_level.

Source code in climakitae/util/warming_levels.py
def create_new_warming_level_table(warming_level: float) -> pd.DataFrame:
    """Returns a table of timestamps when each simulation reaches the given warming level.

    Parameters
    ----------
    warming_level : float
        New WL to retrieve WL timing for.

    Returns
    -------
    pd.DataFrame
        Same DataFrame as `data/gwl_1850-1900ref.csv`, with an added column for the passed `warming_level`.

    """
    df, other_df = read_warming_level_csvs()

    # Map each simulation to its crossing timestamp for the given warming level
    wl_timestamps = {
        col: get_wl_timestamp(df[col], warming_level) for col in df.columns
    }

    result = other_df.copy(deep=True)
    result["sim"] = result["GCM"] + "_" + result["run"] + "_" + result["scenario"]
    timestamp_series = pd.Series(wl_timestamps)

    result[str(warming_level)] = result["sim"].map(timestamp_series)
    result = result.drop(columns="sim")
    result = result.set_index(["GCM", "run", "scenario"])

    return result
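
For example (the 2.5 degree level is only an illustration; any level covered by the packaged trajectory data works):

>>> from climakitae.util.warming_levels import create_new_warming_level_table
>>> table = create_new_warming_level_table(2.5)
>>> table.index.names
FrozenList(['GCM', 'run', 'scenario'])
>>> "2.5" in table.columns
True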

filter_warming_trajectories_to_ae(simulations_df, warming_trajectories, downscaling_method)

Filters the simulations in warming_trajectories down to those available on AE (simulations_df), additionally restricting them to the given downscaling_method.

Parameters:

Name Type Description Default
simulations_df DataFrame

Complete simulation dataframe of all simulations in GWL tables.

required
warming_trajectories DataFrame

Full warming trajectory DataFrame, computed from read_warming_level_csvs.

required
downscaling_method str

Downscaling method to filter DataFrame by ('LOCA' or 'WRF').

required

Returns:

Type Description
DataFrame

Filtered simulations_df to only simulations accessible on AE.

Source code in climakitae/util/warming_levels.py
def filter_warming_trajectories_to_ae(
    simulations_df: pd.DataFrame,
    warming_trajectories: pd.DataFrame,
    downscaling_method: str,
) -> pd.DataFrame:
    """Filters all simulations in `warming_trajectories` to only the ones we have on AE (`simulations_df`).
    Does this filtering by `downscaling_method` as well.

    Parameters
    ----------
    simulations_df : pd.DataFrame
        Complete simulation dataframe of all simulations in GWL tables.
    warming_trajectories : pd.DataFrame
        Full warming trajectory DataFrame, computed from `read_warming_level_csvs`.
    downscaling_method : str
        Downscaling method to filter DataFrame by ('LOCA' or 'WRF').

    Returns
    -------
    pd.DataFrame
        Filtered `simulations_df` to only simulations accessible on AE.

    """
    columns_to_keep = []
    activity_simulations = simulations_df[
        simulations_df["activity_id"] == downscaling_method
    ]

    for _, row in activity_simulations.iterrows():
        pattern = f"{row['source_id']}_{row['member_id']}_{row['experiment_id']}"
        matches = [col for col in warming_trajectories.columns if pattern in col]
        columns_to_keep.extend(matches)

    return warming_trajectories[columns_to_keep]

create_ae_warming_trajectories(resolution)

Creates warming trajectories for all AE simulations based on a given resolution. This resolution is an important parameter because not all resolutions have the same number of WRF simulations (e.g., 3km has 8 while 9km has 10).

Parameters:

Name Type Description Default
resolution str

Grid resolution (e.g., "6km", "12km").

required

Returns:

Type Description
tuple[DataFrame, DataFrame]

LOCA2 warming trajectories (pd.DataFrame)
WRF warming trajectories (pd.DataFrame)

Source code in climakitae/util/warming_levels.py
def create_ae_warming_trajectories(
    resolution: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Creates warming trajectories for all AE simulations based on a given resolution.
    This resolution is an important parameter because not all resolutions have the same number of WRF simulations (e.g., 3km has 8 while 9km has 10).

    Parameters
    ----------
    resolution : str
        Grid resolution (e.g., "6km", "12km").

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        LOCA2 warming trajectories (pd.DataFrame)
        WRF warming trajectories (pd.DataFrame)

    """
    df = intake.open_esm_datastore(DATA_CATALOG_URL).df
    grid_label = resolution_to_gridlabel(resolution)

    # Only select simulations with the given grid label, since WRF has a different number of simulations depending on the spatial resolution
    select_sims = df[df["grid_label"] == grid_label]

    simulations_df = (
        select_sims[["activity_id", "source_id", "experiment_id", "member_id"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    warming_trajectories, _ = read_warming_level_csvs()

    loca2 = filter_warming_trajectories_to_ae(
        simulations_df, warming_trajectories, "LOCA2"
    )
    wrf = filter_warming_trajectories_to_ae(simulations_df, warming_trajectories, "WRF")

    return loca2, wrf
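
A minimal sketch, assuming "9km" is an accepted resolution string (the exact format follows resolution_to_gridlabel) and that the AE data catalog is reachable:

>>> from climakitae.util.warming_levels import create_ae_warming_trajectories
>>> loca2_traj, wrf_traj = create_ae_warming_trajectories("9km")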

generate_ssp_dict()

Loads historical and SSP scenario CSVs into one dictionary.

Returns:

Type Description
dict[str, DataFrame]

A dictionary mapping scenario names to their pandas DataFrames, indexed by year.

Source code in climakitae/util/warming_levels.py
def generate_ssp_dict() -> dict[str, pd.DataFrame]:
    """Loads historical and SSP scenario CSVs into one dictionary.

    Returns
    -------
    dict[str, pd.DataFrame]
        A dictionary mapping scenario names to their pandas DataFrames, indexed by year.

    """
    files_dict = {
        "Historical": HIST_FILE,
        "SSP 1-1.9": SSP119_FILE,
        "SSP 1-2.6": SSP126_FILE,
        "SSP 2-4.5": SSP245_FILE,
        "SSP 3-7.0": SSP370_FILE,
        "SSP 5-8.5": SSP585_FILE,
    }
    return {
        ssp_str: read_csv_file(filename, index_col="Year")
        for ssp_str, filename in files_dict.items()
    }
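
A quick sketch:

>>> from climakitae.util.warming_levels import generate_ssp_dict
>>> ssp_data = generate_ssp_dict()
>>> sorted(ssp_data.keys())
['Historical', 'SSP 1-1.9', 'SSP 1-2.6', 'SSP 2-4.5', 'SSP 3-7.0', 'SSP 5-8.5']
>>> ssp_data["SSP 2-4.5"].index.name
'Year'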

get_gwl_at_year(year, ssp='all')

Retrieve estimated Global Warming Level (GWL) statistics for a given year.

Parameters:

Name Type Description Default
year int

The year for which to retrieve GWL estimates.

required
ssp str

The SSP scenario to use. Use 'all' to retrieve results for all SSPs.

'all'

Returns:

Type Description
DataFrame

A DataFrame with SSPs as rows and '5%', 'Mean', and '95%' as columns, containing the warming level estimates for the specified year.

Source code in climakitae/util/warming_levels.py
def get_gwl_at_year(year: int, ssp: str = "all") -> pd.DataFrame:
    """Retrieve estimated Global Warming Level (GWL) statistics for a given year.

    Parameters
    ----------
    year : int
        The year for which to retrieve GWL estimates.
    ssp : str, default='all'
        The SSP scenario to use. Use 'all' to retrieve results for all SSPs.

    Returns
    -------
    pd.DataFrame
        A DataFrame with SSPs as rows and '5%', 'Mean', and '95%' as columns,
        containing the warming level estimates for the specified year.

    """
    ssp_dict = generate_ssp_dict()
    wl_timing_df = pd.DataFrame(columns=["5%", "Mean", "95%"])

    if year >= 2015:
        ssp_list = (
            ["SSP 1-1.9", "SSP 1-2.6", "SSP 2-4.5", "SSP 3-7.0", "SSP 5-8.5"]
            if ssp == "all"
            else [ssp]
        )
        # Find the data for the given year and different scenarios
        for scenario in ssp_list:
            wl_by_year_for_scenario = ssp_dict.get(scenario)
            if year not in wl_by_year_for_scenario.index:
                print(f"Year {year} not found in {scenario}")
                wl_timing_df.loc[scenario] = [np.nan, np.nan, np.nan]
            else:
                wl_timing_df.loc[scenario] = round(wl_by_year_for_scenario.loc[year], 2)

    else:
        # Finding the data from the historical period
        if ssp != "all":
            print(f"Year {year} before 2015, using Historical data")
        hist_data = ssp_dict["Historical"]

        if year not in hist_data.index:
            print(f"Year {year} not found in Historical")
            wl_timing_df.loc["Historical"] = [np.nan, np.nan, np.nan]
        else:
            wl_timing_df.loc["Historical"] = round(hist_data.loc[year], 2)

    return wl_timing_df
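
For example (only the shape of the result is shown, not the values):

>>> from climakitae.util.warming_levels import get_gwl_at_year
>>> gwl_2050 = get_gwl_at_year(2050)           # one row per SSP scenario
>>> gwl_2050.columns.tolist()
['5%', 'Mean', '95%']
>>> get_gwl_at_year(2050, ssp="SSP 3-7.0")     # single scenario
>>> get_gwl_at_year(1990)                      # pre-2015 years fall back to the Historical record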

get_year_at_gwl(gwl, ssp='all')

Retrieve the year when a given Global Warming Level (GWL) is reached for each SSP scenario.

Parameters:

Name Type Description Default
gwl float | int

The Global Warming Level to check (e.g., 1.5, 2.0).

required
ssp str

The SSP scenario to evaluate. Use 'all' to check across all SSPs and the Historical period.

'all'

Returns:

Type Description
DataFrame

A DataFrame with SSPs as rows and columns ['5%', 'Mean', '95%'] indicating the years when each warming level threshold is crossed for the respective uncertainty bounds. NaN indicates the level was not reached by 2100.

Source code in climakitae/util/warming_levels.py
def get_year_at_gwl(gwl: Union[float, int], ssp: str = "all") -> pd.DataFrame:
    """Retrieve the year when a given Global Warming Level (GWL) is reached for each SSP scenario.

    Parameters
    ----------
    gwl : float | int
        The Global Warming Level to check (e.g., 1.5, 2.0).
    ssp : str, default='all'
        The SSP scenario to evaluate. Use 'all' to check across all SSPs and the Historical period.

    Returns
    -------
    pd.DataFrame
        A DataFrame with SSPs as rows and columns ['5%', 'Mean', '95%'] indicating the years
        when each warming level threshold is crossed for the respective uncertainty bounds.
        NaN indicates the level was not reached by 2100.

    """
    ssp_dict = generate_ssp_dict()

    wl_timing_df = pd.DataFrame(columns=["5%", "Mean", "95%"])

    ssp_list = (
        ["Historical", "SSP 1-1.9", "SSP 1-2.6", "SSP 2-4.5", "SSP 3-7.0", "SSP 5-8.5"]
        if ssp == "all"
        else [ssp]
    )

    for ssp in ssp_list:
        ssp_selected = ssp_dict[ssp]

        mean_mask = ssp_selected["Mean"] > gwl
        upper_mask = ssp_selected["95%"] > gwl
        lower_mask = ssp_selected["5%"] > gwl

        def first_wl_year(one_ssp: pd.DataFrame, mask: pd.Series) -> Union[int, float]:
            """Return the first year where the pd.Series mask is True, or NaN if none."""
            if mask.any():
                return round(one_ssp.index[mask][0], 0)
            else:
                return np.nan

        # Only add data for a scenario if the mean and upper bound of uncertainty reach the gwl
        if mean_mask.any() and upper_mask.any() and (not mean_mask.all()):
            year_gwl_reached = first_wl_year(ssp_selected, mean_mask)
            x_95 = first_wl_year(ssp_selected, lower_mask)

            # If the lower bound is outside the range of the ssp, use the historical data
            if upper_mask.all():
                x_5 = first_wl_year(
                    ssp_dict["Historical"], (ssp_dict["Historical"]["95%"] > gwl)
                )
            else:
                x_5 = first_wl_year(ssp_selected, upper_mask)

        else:
            x_5 = x_95 = year_gwl_reached = np.nan

        wl_timing_df.loc[ssp] = [x_5, year_gwl_reached, x_95]

    try:
        wl_timing_df = wl_timing_df.astype("Int64")
    except Exception:
        warnings.warn(
            "Error converting years to int, data may have unexpected issues.",
            UserWarning,
        )
        pass

    return wl_timing_df
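
For example:

>>> from climakitae.util.warming_levels import get_year_at_gwl
>>> years_2c = get_year_at_gwl(2.0)            # Historical plus each SSP as rows
>>> years_2c.columns.tolist()
['5%', 'Mean', '95%']
>>> get_year_at_gwl(1.5, ssp="SSP 2-4.5")      # single scenario; <NA> where the level is not reached by 2100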

Unit Conversions

Calculates alternative units for variables with multiple commonly used units, following NWS conversions for pressure and wind speed.

get_unit_conversion_options()

Get dictionary of unit conversion options offered for each unit

Source code in climakitae/util/unit_conversions.py
def get_unit_conversion_options() -> dict:
    """Get dictionary of unit conversion options offered for each unit"""
    options = {
        "K": ["K", "degC", "degF"],
        "degF": ["K", "degC", "degF"],
        "degC": ["K", "degC", "degF"],
        "hPa": ["Pa", "hPa", "mb", "inHg"],
        "Pa": ["Pa", "hPa", "mb", "inHg"],
        "m/s": ["m/s", "mph", "knots"],
        "m s-1": ["m s-1", "mph", "knots"],
        "[0 to 100]": ["[0 to 100]", "fraction"],
        "mm": ["mm", "inches"],
        "mm/d": ["mm/d", "inches/d"],
        "mm/h": ["mm/h", "inches/h"],
        "kg/kg": ["kg/kg", "g/kg"],
        "kg kg-1": ["kg kg-1", "g kg-1"],
        "kg m-2 s-1": ["kg m-2 s-1", "mm", "inches"],
        "g/kg": ["g/kg", "kg/kg"],
    }
    return options
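
For example:

>>> from climakitae.util.unit_conversions import get_unit_conversion_options
>>> options = get_unit_conversion_options()
>>> options["K"]
['K', 'degC', 'degF']
>>> options["Pa"]
['Pa', 'hPa', 'mb', 'inHg']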

convert_units(da, selected_units)

Converts units for any variable

Parameters:

Name Type Description Default
da DataArray

data

required
selected_units str

selected units of data, from selections.units

required

Returns:

Name Type Description
da DataArray

data with converted units and updated units attribute

References

Wind speed: https://www.weather.gov/media/epz/wxcalc/windConversion.pdf
Pressure: https://www.weather.gov/media/epz/wxcalc/pressureConversion.pdf

Source code in climakitae/util/unit_conversions.py
def convert_units(da: xr.DataArray, selected_units: str) -> xr.DataArray:
    """Converts units for any variable

    Parameters
    ----------
    da : xr.DataArray
        data
    selected_units : str
        selected units of data, from selections.units

    Returns
    -------
    da : xr.DataArray
        data with converted units and updated units attribute

    References
    ----------
    Wind speed: https://www.weather.gov/media/epz/wxcalc/windConversion.pdf
    Pressure: https://www.weather.gov/media/epz/wxcalc/pressureConversion.pdf

    """

    # Get native units of data from attributes
    try:
        native_units = da.attrs["units"]
    except KeyError:
        raise ValueError(
            (
                "You've encountered a bug in the code. This variable "
                "does not have identifiable native units. The data"
                " for this variable will need to have a 'units'"
                " attribute added in the catalog."
            )
        )

    # Convert hPa to Pa to make conversions easier
    # Monthly data native unit is hPa, hourly is Pa
    if native_units == "hPa" and selected_units != "hPa":
        da = da * 100.0  # 1 hPa = 100 Pa
        da.attrs["units"] = "Pa"
        native_units = "Pa"

    # Pass if chosen units is the same as native units
    match native_units:
        case native_units if native_units == selected_units:
            return da

        # Precipitation units
        case "mm" | "mm/d" | "mm/h":
            match selected_units:
                case "inches" | "inches/h" | "inches/d":
                    da = da / 25.4
                case "kg m-2 s-1":
                    da = da / 86400
                    if da.attrs["frequency"] == "monthly":
                        da_name = da.name
                        da = da / da["time"].dt.days_in_month
                        da.name = da_name  # Name is lost during computation above
        case "kg m-2 s-1":
            if da.attrs["frequency"] == "monthly":
                da_name = da.name
                da = da * da["time"].dt.days_in_month
                da.name = da_name  # Name is lost during computation above
            match selected_units:
                case "mm":
                    da = da * 86400
                case "inches":
                    da = (da * 86400) / 25.4

        # Moisture ratio units
        case "kg/kg" | "kg kg-1":
            if selected_units in ["g/kg", "g kg-1"]:
                da = da * 1000

        # Specific humidity
        case "g/kg":
            if selected_units == "kg/kg":
                da = da / 1000
                da.attrs["units"] = selected_units

        # Pressure units
        case "Pa":
            match selected_units:
                case "hPa":
                    da = da / 100.0
                case "mb":
                    da = da / 100.0
                case "inHg":
                    da = da * 0.000295300

        # Wind units
        case "m/s" | "m s-1":
            match selected_units:
                case "knots":
                    da = da * 1.9438445
                case "mph":
                    da = da * 2.236936

        # Temperature units
        case "K":
            match selected_units:
                case "degC":
                    da = da - 273.15
                case "degF":
                    da = (1.8 * (da - 273.15)) + 32
        case "degC":
            match selected_units:
                case "K":
                    da = da + 273.15
                case "degF":
                    da = (1.8 * da) + 32
        case "degF":
            # Convert to C
            match selected_units:
                case "degC":
                    da = (da - 32) / 1.8
                # Then, if K is selected, convert to K
                case "K":
                    da = (da - 32) / 1.8
                    da = da + 273.15

        # Fraction/percentage units (relative humidity)
        case "[0 to 100]":
            if selected_units == "fraction":
                da = da / 100.0

    # Update unit attribute to reflect converted unit
    da.attrs["units"] = selected_units
    return da
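
A self-contained sketch with a temperature DataArray (the 'units' attribute must be present, since the function reads it):

>>> import xarray as xr
>>> from climakitae.util.unit_conversions import convert_units
>>> da = xr.DataArray([273.15, 293.15], dims="time", attrs={"units": "K"})
>>> convert_units(da, "degC").values
array([ 0., 20.])
>>> convert_units(da, "degF").attrs["units"]
'degF'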

Colormaps

read_ae_colormap(cmap='ae_orange', cmap_hex=False)

Read in AE colormap by name

Parameters:

Name Type Description Default
cmap str

one of ["ae_orange", "ae_diverging", "ae_blue", "ae_diverging_r", "categorical_cb"]

'ae_orange'
cmap_hex boolean

return RGB or hex colors?

False

Returns:

Name Type Description
One of the following, depending on cmap_hex:
cmap_data LinearSegmentedColormap

used for matplotlib (if cmap_hex == False)

cmap_data list

used for hvplot maps (if cmap_hex == True)

Source code in climakitae/util/colormap.py
def read_ae_colormap(
    cmap: str = "ae_orange", cmap_hex: bool = False
) -> matplotlib.colors.LinearSegmentedColormap | list:
    """Read in AE colormap by name

    Parameters
    ----------
    cmap : str
        one of ["ae_orange", "ae_diverging", "ae_blue", "ae_diverging_r", "categorical_cb"]
    cmap_hex : boolean
        return RGB or hex colors?

    Returns
    -------
    One of the following, depending on cmap_hex:

    cmap_data : matplotlib.colors.LinearSegmentedColormap
        used for matplotlib (if cmap_hex == False)
    cmap_data : list
        used for hvplot maps (if cmap_hex == True)

    """

    match cmap:
        case "ae_orange":
            cmap_data = AE_ORANGE
        case "ae_diverging":
            cmap_data = AE_DIVERGING
        case "ae_blue":
            cmap_data = AE_BLUE
        case "ae_diverging_r":
            cmap_data = AE_DIVERGING_R
        case "categorical_cb":
            cmap_data = CATEGORICAL_CB
        case _:
            raise ValueError(
                'cmap needs to be one of ["ae_orange", "ae_diverging", "ae_blue", "ae_diverging_r", "categorical_cb"]'
            )

    # Load text file
    cmap_np = np.loadtxt(_package_file_path(cmap_data), dtype=float)

    # RGB to hex
    if cmap_hex:
        cmap_data = [matplotlib.colors.rgb2hex(color) for color in cmap_np]
    else:
        cmap_data = mcolors.LinearSegmentedColormap.from_list(cmap, cmap_np, N=256)
    return cmap_data
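
For example:

>>> from climakitae.util.colormap import read_ae_colormap
>>> cmap = read_ae_colormap("ae_orange")                           # LinearSegmentedColormap for matplotlib
>>> hex_colors = read_ae_colormap("ae_diverging", cmap_hex=True)   # list of hex strings for hvplot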

Cluster Management

Cluster

Bases: GatewayCluster

A dask-gateway cluster that limits each user to a single cluster. Instead of always creating a new cluster, it reconnects to a previously running cluster for the user when one exists.

Methods:

Name Description
get_client

Get a dask client connected to the cluster.

Examples:

>>> from climakitae.util.cluster import Cluster
>>> cluster = Cluster() # Create cluster
>>> cluster.adapt(minimum=0, maximum=8) # Specify the number of workers to use
>>> client = cluster.get_client()
>>> cluster # Output cluster information

get_client(set_as_default=True)

Get client

Returns:

Type Description
Client
Source code in climakitae/util/cluster.py
def get_client(self, set_as_default: bool = True) -> Client:
    """Get client

    Returns
    -------
    distributed.client.Client

    """
    clusters = self.gateway.list_clusters()
    if clusters:
        cluster = self.gateway.connect(clusters.pop().name, shutdown_on_close=True)
        for c in clusters:
            self.gateway.stop_cluster(c.name)
        client = cluster.get_client()
    else:
        client = super().get_client(set_as_default)
    plugin = PipInstall(packages=self.extra_packages)
    client.register_worker_plugin(plugin)
    return client

Logger Configuration

enable_app_logging()

Enables application logging by adding a handler.

Source code in climakitae/util/logger.py
def enable_app_logging():
    """Enables application logging by adding a handler."""
    global app_log_enabled
    if not app_log_enabled:
        app_log_enabled = True
        logger.addHandler(handler)

disable_app_logging()

Disables logging by removing application logger handler.

Source code in climakitae/util/logger.py
def disable_app_logging():
    """Disables logging by removing application logger handler."""
    global app_log_enabled
    if app_log_enabled:
        app_log_enabled = False
        logger.removeHandler(handler)
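
A minimal sketch of toggling application logging (using the module-level logger and handler referenced in the source):

>>> from climakitae.util.logger import enable_app_logging, disable_app_logging
>>> enable_app_logging()    # attach the handler; climakitae log messages are emitted
>>> disable_app_logging()   # remove the handler to silence application logging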