Destaggering and vertically interpolating dask data using xgcm

In this tutorial, we will show you how to leverage the xWRF accessors and the xWRF-provided COMODO-compliant attributes in order to destagger the WRF output and interpolate it vertically using dask and xgcm.

Loading the data

First, we load the data and post-process it with the simple .xwrf.postprocess() API. Passing the chunks kwarg to open_dataset makes the dataset dask-backed. In a real-world scenario, you might want to spawn a Cluster to speed up calculations.

import xwrf
ds = xwrf.tutorial.open_dataset("wrfout", chunks='auto').xwrf.postprocess()
ds
<xarray.Dataset> Size: 104MB
Dimensions:                    (Time: 1, z: 39, y: 340, x_stag: 271,
                                y_stag: 341, x: 270)
Coordinates: (12/13)
  * Time                       (Time) datetime64[us] 8B 2099-10-01
    XTIME                      (Time) datetime64[ns] 8B dask.array<chunksize=(1,), meta=np.ndarray>
  * z                          (z) float32 156B 0.9969 0.9899 ... 0.002948
  * y                          (y) float64 3kB -3.341e+05 ... 2.717e+06
  * x_stag                     (x_stag) float64 2kB -4.733e+06 ... -2.303e+06
    XLAT_U                     (y, x_stag) float32 369kB dask.array<chunksize=(340, 271), meta=np.ndarray>
    ...                         ...
  * y_stag                     (y_stag) float64 3kB -3.386e+05 ... 2.721e+06
  * x                          (x) float64 2kB -4.728e+06 ... -2.307e+06
    XLAT                       (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>
    XLONG                      (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>
    XLAT_V                     (y_stag, x) float32 368kB dask.array<chunksize=(341, 270), meta=np.ndarray>
    XLONG_V                    (y_stag, x) float32 368kB dask.array<chunksize=(341, 270), meta=np.ndarray>
Data variables:
    Times                      (Time) |S19 19B dask.array<chunksize=(1,), meta=np.ndarray>
    U                          (Time, z, y, x_stag) float32 14MB dask.array<chunksize=(1, 39, 340, 271), meta=np.ndarray>
    V                          (Time, z, y_stag, x) float32 14MB dask.array<chunksize=(1, 39, 341, 270), meta=np.ndarray>
    SINALPHA                   (Time, y, x) float32 367kB dask.array<chunksize=(1, 340, 270), meta=np.ndarray>
    COSALPHA                   (Time, y, x) float32 367kB dask.array<chunksize=(1, 340, 270), meta=np.ndarray>
    QVAPOR                     (Time, z, y, x) float32 14MB dask.array<chunksize=(1, 39, 340, 270), meta=np.ndarray>
    PSN                        (Time, y, x) float32 367kB dask.array<chunksize=(1, 340, 270), meta=np.ndarray>
    air_potential_temperature  (Time, z, y, x) float32 14MB dask.array<chunksize=(1, 39, 340, 270), meta=np.ndarray>
    air_pressure               (Time, z, y, x) float32 14MB dask.array<chunksize=(1, 39, 340, 270), meta=np.ndarray>
    wind_east                  (Time, z, y, x) float32 14MB dask.array<chunksize=(1, 39, 340, 270), meta=np.ndarray>
    wind_north                 (Time, z, y, x) float32 14MB dask.array<chunksize=(1, 39, 340, 270), meta=np.ndarray>
    wrf_projection             object 8B +proj=lcc +x_0=0 +y_0=0 +a=6370000 +...
Attributes: (12/149)
    TITLE:                            OUTPUT FROM WRF V4.1.3 MODEL
    START_DATE:                      2099-08-01_00:00:00
    SIMULATION_START_DATE:           2099-08-01_00:00:00
    WEST-EAST_GRID_DIMENSION:        271
    SOUTH-NORTH_GRID_DIMENSION:      341
    BOTTOM-TOP_GRID_DIMENSION:       40
    ...                              ...
    ISLAKE:                          21
    ISICE:                           15
    ISURBAN:                         13
    ISOILWATER:                      14
    HYBRID_OPT:                      0
    ETAC:                            0.0

Destaggering

If we naively try to calculate the wind speed from the U and V components, we get an error because the two arrays have different shapes.

from metpy.calc import wind_speed

wind_speed(ds.U, ds.V)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 3
      1 from metpy.calc import wind_speed
      2 
----> 3 wind_speed(ds.U, ds.V)

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/metpy/xarray.py:1330, in preprocess_and_wrap.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
   1327     _mutate_arguments(bound_args, units.Quantity, lambda arg, _: arg.m)
   1329 # Evaluate inner calculation
-> 1330 result = func(*bound_args.args, **bound_args.kwargs)
   1332 # Wrap output based on match and match_unit
   1333 if match is None:

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/metpy/units.py:337, in check_units.<locals>.dec.<locals>.wrapper(*args, **kwargs)
    334 @functools.wraps(func)
    335 def wrapper(*args, **kwargs):
    336     _check_units_inner_helper(func, sig, defaults, dims, *args, **kwargs)
--> 337     return func(*args, **kwargs)

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/metpy/calc/basic.py:63, in wind_speed(u, v)
     33 @exporter.export
     34 @preprocess_and_wrap(wrap_like='u')
     35 @check_units('[speed]', '[speed]')
     36 def wind_speed(u, v):
     37     r"""Compute the wind speed from u and v-components.
     38 
     39     Parameters
   (...)     61 
     62     """
---> 63     return np.hypot(u, v)

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/pint/facets/numpy/quantity.py:72, in NumpyQuantity.__array_ufunc__(self, ufunc, method, *inputs, **kwargs)
     65 # Replicate types from __array_function__
     66 types = {
     67     type(arg)
     68     for arg in list(inputs) + list(kwargs.values())
     69     if hasattr(arg, "__array_ufunc__")
     70 }
---> 72 return numpy_wrap("ufunc", ufunc, inputs, kwargs, types)

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/pint/facets/numpy/numpy_func.py:1133, in numpy_wrap(func_type, func, args, kwargs, types)
   1131 if name not in handled or any(is_upcast_type(t) for t in types):
   1132     return NotImplemented
-> 1133 return handled[name](*args, **kwargs)

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/pint/facets/numpy/numpy_func.py:324, in implement_func.<locals>.implementation(*args, **kwargs)
    319     stripped_args, stripped_kwargs = convert_to_consistent_units(
    320         *args, pre_calc_units=pre_calc_units, **kwargs
    321     )
    323 # Determine result through plain numpy function on stripped arguments
--> 324 result_magnitude = func(*stripped_args, **stripped_kwargs)
    326 if output_unit is None:
    327     # Short circuit and return magnitude alone
    328     return result_magnitude

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/dask/array/core.py:1597, in Array.__array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs)
   1595         return da_ufunc(*inputs, **kwargs)
   1596     else:
-> 1597         return elemwise(numpy_ufunc, *inputs, **kwargs)
   1598 elif method == "outer":
   1599     from dask.array import ufunc

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/dask/array/core.py:5152, in elemwise(op, out, where, dtype, name, *args, **kwargs)
   5148     shapes.append(out.shape)
   5150 shapes = [s if isinstance(s, Iterable) else () for s in shapes]
   5151 out_ndim = len(
-> 5152     broadcast_shapes(*shapes)
   5153 )  # Raises ValueError if dimensions mismatch
   5154 expr_inds = tuple(range(out_ndim))[::-1]
   5156 if dtype is not None:

File ~/checkouts/readthedocs.org/user_builds/xwrf/conda/latest/lib/python3.12/site-packages/dask/array/core.py:5080, in broadcast_shapes(*shapes)
   5078         dim = 0 if 0 in sizes else np.max(sizes).item()
   5079     if any(i not in [-1, 0, 1, dim] and not np.isnan(i) for i in sizes):
-> 5080         raise ValueError(
   5081             "operands could not be broadcast together with "
   5082             "shapes {}".format(" ".join(map(str, shapes)))
   5083         )
   5084     out.append(dim)
   5085 return tuple(reversed(out))

ValueError: operands could not be broadcast together with shapes (1, 39, 340, 271) (1, 39, 341, 270)

Upon investigating the wind components, we can see that they are defined on the WRF-internal Arakawa-C grid: U is staggered in x and V is staggered in y, which causes the shapes to differ.

ds.U.sizes, ds.V.sizes
(Frozen({'Time': 1, 'z': 39, 'y': 340, 'x_stag': 271}),
 Frozen({'Time': 1, 'z': 39, 'y_stag': 341, 'x': 270}))
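
To see what destaggering does conceptually, here is a minimal NumPy sketch. It assumes that destaggering amounts to averaging each pair of neighbouring points along the staggered dimension; the destagger helper and the small random arrays are hypothetical stand-ins for illustration, not the xWRF implementation or the tutorial data.

```python
import numpy as np


def destagger(arr, axis):
    """Average neighbouring staggered points along `axis` (hypothetical helper)."""
    n = arr.shape[axis]
    lower = np.take(arr, np.arange(0, n - 1), axis=axis)
    upper = np.take(arr, np.arange(1, n), axis=axis)
    return 0.5 * (lower + upper)


# Small stand-ins with the same staggering pattern as U and V:
# u_stag has dims (z, y, x_stag), v_stag has dims (z, y_stag, x).
rng = np.random.default_rng(0)
u_stag = rng.normal(size=(3, 4, 6)).astype("float32")
v_stag = rng.normal(size=(3, 5, 5)).astype("float32")

# On the staggered grid the shapes cannot be broadcast together.
try:
    np.hypot(u_stag, v_stag)
except ValueError as err:
    print("staggered:", err)

# After destaggering, both components live on the same (z, y, x) points.
u = destagger(u_stag, axis=-1)
v = destagger(v_stag, axis=-2)
speed = np.hypot(u, v)
print("destaggered:", speed.shape)
```

Each staggered dimension loses one point, which is why x_stag: 271 becomes x: 270 and y_stag: 341 becomes y: 340 in the outputs further down.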

Destaggering takes a single call with the .xwrf accessor. We can decide whether to destagger the whole Dataset …

destaggered = ds.xwrf.destagger().metpy.quantify()
destaggered['wind_speed'] = wind_speed(destaggered.U, destaggered.V)
destaggered.wind_speed
<xarray.DataArray 'wind_speed' (Time: 1, z: 39, y: 340, x: 270)> Size: 14MB
<Quantity(dask.array<hypot, shape=(1, 39, 340, 270), dtype=float32, chunksize=(1, 39, 340, 270), chunktype=numpy.ndarray>, 'meter / second')>
Coordinates:
  * Time     (Time) datetime64[us] 8B 2099-10-01
    XTIME    (Time) datetime64[ns] 8B dask.array<chunksize=(1,), meta=np.ndarray>
  * z        (z) float32 156B 0.9969 0.9899 0.981 ... 0.0161 0.009174 0.002948
  * y        (y) float64 3kB -3.341e+05 -3.251e+05 ... 2.708e+06 2.717e+06
  * x        (x) float64 2kB -4.728e+06 -4.719e+06 ... -2.316e+06 -2.307e+06
    XLAT     (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>
    XLONG    (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>

… or whether we just want to destagger the two individual DataArrays.

ds = ds.metpy.quantify()
wind_speed(ds.U.xwrf.destagger(), ds.V.xwrf.destagger())
<xarray.DataArray 'hypot-e3c13d520042b90d13385aee5e7f5c70' (Time: 1, z: 39,
                                                            y: 340, x: 270)> Size: 14MB
<Quantity(dask.array<hypot, shape=(1, 39, 340, 270), dtype=float32, chunksize=(1, 39, 340, 270), chunktype=numpy.ndarray>, 'meter / second')>
Coordinates:
  * Time     (Time) datetime64[us] 8B 2099-10-01
    XTIME    (Time) datetime64[ns] 8B dask.array<chunksize=(1,), meta=np.ndarray>
  * z        (z) float32 156B 0.9969 0.9899 0.981 ... 0.0161 0.009174 0.002948
  * y        (y) float64 3kB -3.341e+05 -3.251e+05 ... 2.708e+06 2.717e+06
  * x        (x) float64 2kB -4.728e+06 -4.719e+06 ... -2.316e+06 -2.307e+06
    XLAT     (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>
    XLONG    (y, x) float32 367kB dask.array<chunksize=(340, 270), meta=np.ndarray>

Vertical interpolation using xgcm

We have now calculated the wind speed for the whole model domain. However, the vertical levels are still on the native WRF sigma coordinate, which is impractical for most analyses. To analyze this data properly, we have to interpolate it onto pressure levels.

Since xWRF prepared the dataset with the appropriate COMODO (and units) attributes, we can simply use xgcm and its Grid.transform function to solve this problem. However, since xgcm doesn't understand units yet, we have to dequantify the data first:

import xgcm
import numpy as np
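import pint_xarray  # registers the .pint accessor used below

target_levels = np.array([250.])  # in hPa
# Convert the pressure to hPa, then strip the units before handing it to xgcm
air_pressure = destaggered.air_pressure.pint.to('hPa').metpy.dequantify()

grid = xgcm.Grid(destaggered, periodic=False)
# Interpolate the wind speed along the vertical axis onto the target pressure levels
_wind_speed = grid.transform(
    destaggered.wind_speed.metpy.dequantify(),
    'Z',
    target_levels,
    target_data=air_pressure,
    method='log',
)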
_wind_speed = _wind_speed.compute()
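
For intuition, the following sketch interpolates a single column by hand. It assumes that method='log' means linear interpolation in the logarithm of the pressure, which is a reading of the option name rather than a description of xgcm's internals. The interp_log_pressure helper and the pressure and wind values are made up for illustration.

```python
import numpy as np


def interp_log_pressure(values, pressure, target):
    """Interpolate one column to `target`, linearly in log(pressure).

    Hypothetical helper; `pressure` must decrease monotonically with height.
    """
    # np.interp expects increasing x values, so reverse the column first.
    log_p = np.log(pressure[::-1])
    return np.interp(np.log(target), log_p, values[::-1])


# Made-up column, ordered from the surface upwards.
pressure = np.array([1000.0, 850.0, 500.0, 300.0, 200.0, 100.0])  # hPa
wind = np.array([5.0, 10.0, 20.0, 35.0, 45.0, 30.0])  # m/s

w250 = interp_log_pressure(wind, pressure, 250.0)
print(f"wind speed at 250 hPa: {w250:.2f} m/s")
```

The result lies between the values at the two neighbouring levels (300 hPa and 200 hPa), weighted by their distance in log-pressure rather than in pressure itself. Grid.transform does the equivalent for every column of the dask-backed array.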

Finally, we can plot the result using hvplot.

import hvplot.xarray

_wind_speed.hvplot.quadmesh(
    x='XLONG',
    y='XLAT',
    title='Wind speed at 250 hPa',
    geo=True,
    project=True,
    alpha=0.9,
    cmap='inferno',
    clim=(_wind_speed.min().item(), _wind_speed.max().item()),
    clabel='wind speed [m/s]',
    tiles='OSM',
    dynamic=False
)