# Source code for aris_lite.yield_expectation

#!/usr/bin/env python

"""Stress and yield estimation for ARIS-lite.

This module is the final stage of the ARIS-lite chain:
phenology provides crop activity (`Kc_factor`, `plant_height`), water budget
provides evapotranspiration and soil depletion, and meteorology provides
temperature. From these inputs, stress indicators are derived and translated
into crop-wise yield depression.
"""

# Names exported by ``from aris_lite.yield_expectation import *``.
# Helpers not listed here (e.g. the metadata setters and the path-validation
# and file-opening helpers) are left out of the star-import surface.
__all__ = [
    "run_calc_stresses",
    "run_calc_yield",
    "main_heat_drought_compound_stress",
    "main_yield",
    "main_cli",
    "calc_stresses",
    "calc_heat_drought_compound_stress",
    "calc_yield",
]

from pathlib import Path
from typing import Iterable
import xarray as xr
import numpy as np

from aris_lite.deprecation import warn_legacy_python_api
from aris_lite.paths import (
    DEFAULT_BASE_DIR,
    input_taw,
    intermediate_stress,
    intermediate_year,
    output_year,
    reference_year,
)


def set_heat_drought_compound_stress_meta(da: xr.DataArray) -> xr.DataArray:
    """
    Name and annotate the heat-drought compound stress array.

    The array is renamed to ``heat_drought_compound_stress`` and receives the
    ``units``, ``long_name`` and ``description`` attributes defined below.

    :param da: DataArray to annotate.
    :type da: xr.DataArray
    :return: Renamed DataArray carrying the additional attributes.
    :rtype: xr.DataArray
    """
    attributes = {
        # The index is dimensionless, hence the empty unit string.
        "units": "",
        "long_name": (
            "daily crop specific stress index based on maximum "
            "surface air temperature and soil water saturation"
        ),
        "description": (
            "Index of combination of plant growth inhibiting "
            "factors. Used for yield estimation."
        ),
    }
    renamed = da.rename("heat_drought_compound_stress")
    return renamed.assign_attrs(attributes)


def _thresholded_compound_stress(waterstress, max_air_temp, threshold):
    """Return ``waterstress * (T - (threshold - 1)) - 33`` on hot and dry days.

    A day counts as hot and dry when ``max_air_temp > threshold`` and
    ``waterstress > 33``; on every other day plain ``waterstress`` is returned.
    """
    return xr.where(
        np.logical_and(max_air_temp > threshold, waterstress > 33),
        (waterstress * (max_air_temp - (threshold - 1))) - 33,
        waterstress,
    )


def calc_heat_drought_compound_stress(
    ds: xr.Dataset, threshold_temperature=None
) -> xr.DataArray:
    """
    Calculate crop-specific compound heat-drought stress.

    ``waterstress`` is taken from ``ds`` when present; otherwise it is derived
    as ``soil_depletion * 100 / TAW`` for ``layer="top"`` (i.e. in percent).
    The caller's dataset is not modified in that case.

    Per-crop rules (``T`` is ``max_air_temp``, ``ws`` is ``waterstress``):

    * ``irrigated norm potato``: ``T - 25`` where ``T > 26`` and
      ``plant_height > 0``, missing elsewhere (water stress is not used).
    * ``winter wheat``: ``ws * (T - 25)`` where ``T > 26``, else ``ws``.
    * ``spring barley``, ``maize``, ``soybean``:
      ``ws * (T - 29) - 33`` where ``T > 30`` and ``ws > 33``, else ``ws``.
    * ``norm potato``: ``ws * (T - 25) - 33`` where ``T > 26`` and
      ``ws > 33``, else ``ws``.
    * any other crop whose name contains ``potato``: the same formula with
      ``threshold_temperature`` in place of ``26``.
    * ``grassland``: ``ws``.
    * any other crop stays missing (NaN).

    Afterwards values are masked before a crop-specific first month (March for
    winter wheat, spring barley and soybean; April for potatoes; May for
    grassland and maize). Finally a time step is kept only if ``Kc_factor``
    exceeds ``0.5`` on that step or on a later one, so everything after the
    last such step is masked.

    :param ds: Dataset containing ``max_air_temp``, ``Kc_factor`` and ``time``,
        and either ``waterstress`` or ``soil_depletion`` with ``TAW``.
        ``plant_height`` is required when ``irrigated norm potato`` is present.
    :type ds: xr.Dataset
    :param threshold_temperature: Temperature threshold for potato crops other
        than ``norm potato`` and ``irrigated norm potato``.
    :type threshold_temperature: float, optional
    :raises ValueError: If such a potato crop is present and
        ``threshold_temperature`` is ``None``.
    :return: Daily compound stress index named
        ``heat_drought_compound_stress``.
    :rtype: xr.DataArray
    """
    if "waterstress" not in ds:
        # Shallow copy: adding the derived variable must not leak into the
        # dataset owned by the caller.
        ds = ds.copy()
        ds["waterstress"] = (
            ds.soil_depletion.sel(layer="top") * 100 / ds.TAW.sel(layer="top")
        )

    waterstress = ds.waterstress
    max_air_temp = ds.max_air_temp
    month = ds.time.dt.month

    # All-missing array with the layout of ``waterstress``; filled per crop.
    stress = waterstress.where(False)

    # NOTE(review): positional indexing assumes ``crop`` is the first
    # dimension of ``waterstress`` -- confirm against the producers.
    for i in range(stress.shape[0]):
        crop = stress[i].crop.item(0)
        crop_waterstress = waterstress[i]

        # The branches are mutually exclusive. Previously "irrigated norm
        # potato" also fell through to the generic potato rule, which either
        # raised (no threshold given) or overwrote its dedicated result.
        if crop == "irrigated norm potato":
            plant_height = ds.plant_height
            if "crop" in plant_height.dims:
                # A per-crop height cannot be assigned into a single crop
                # slice; pick the matching crop first.
                plant_height = plant_height.sel(crop=crop)
            stress[i] = (max_air_temp - 25).where(
                (max_air_temp > 26) & (plant_height > 0)
            )
        elif crop == "winter wheat":
            stress[i] = xr.where(
                max_air_temp > 26,
                crop_waterstress * (max_air_temp - 25),
                crop_waterstress,
            )
        elif crop in ("spring barley", "maize", "soybean"):
            stress[i] = _thresholded_compound_stress(
                crop_waterstress, max_air_temp, 30
            )
        elif crop == "norm potato":
            stress[i] = _thresholded_compound_stress(
                crop_waterstress, max_air_temp, 26
            )
        elif "potato" in crop:
            if threshold_temperature is None:
                raise ValueError(
                    'For potatoes, other than "(irrigated) norm potato", you must '
                    "pass the `threshold_temperature`."
                )
            stress[i] = _thresholded_compound_stress(
                crop_waterstress, max_air_temp, threshold_temperature
            )
        elif crop == "grassland":
            stress[i] = crop_waterstress

        # Crop-specific first month from which stress is counted.
        if crop in ("winter wheat", "spring barley", "soybean"):
            first_month = 3
        elif "potato" in crop:
            first_month = 4  # actually 15.4. shouldn't matter
        elif crop in ("grassland", "maize"):
            first_month = 5
        else:
            first_month = None
        if first_month is not None:
            stress[i] = stress[i].where(month >= first_month)

    # Reverse cumulative count of active days: non-zero while an active day
    # (Kc_factor > 0.5) still lies at or after the current step.
    # NOTE(review): ``[:, ::-1]`` assumes dims are ordered (crop, time, ...).
    stress = stress.where((ds.Kc_factor > 0.5)[:, ::-1].cumsum("time") != 0)
    return stress.rename("heat_drought_compound_stress")
def calc_drought_stress(
    evapotranspiration: xr.DataArray, pot_evapotransp: xr.DataArray
) -> xr.DataArray:
    """
    Flag days on which actual ET is small relative to potential ET.

    Actual evapotranspiration is summed over ``layer`` and divided by
    ``pot_evapotransp``; the flag is ``True`` where that ratio is below
    ``0.4``.

    :param evapotranspiration: Actual evapotranspiration with a ``layer``
        dimension.
    :type evapotranspiration: xr.DataArray
    :param pot_evapotransp: Potential evapotranspiration.
    :type pot_evapotransp: xr.DataArray
    :return: Boolean flag named ``drought_stress``.
    :rtype: xr.DataArray
    """
    actual_total = evapotranspiration.sum("layer")
    ratio = (actual_total / pot_evapotransp).rename("drought_stress")
    return ratio < 0.4


def calc_heat_stress(max_air_temp: xr.DataArray) -> xr.DataArray:
    """
    Flag days with a maximum air temperature of at least ``32``.

    :param max_air_temp: Daily maximum air temperature.
    :type max_air_temp: xr.DataArray
    :return: Boolean flag named ``heat_stress``.
    :rtype: xr.DataArray
    """
    named = max_air_temp.rename("heat_stress")
    return named >= 32
def calc_stresses(ds: xr.Dataset) -> xr.DataArray:
    """
    Stack all stressors into one array along a ``stressor`` dimension.

    The heat flag, the drought flag and the crop-specific compound
    heat-drought stress are computed from ``ds``, collected in a dataset keyed
    by each array's name, and converted to a single DataArray.

    :param ds: Dataset providing ``max_air_temp``, ``evapotranspiration``,
        ``pot_evapotransp`` and the inputs of
        ``calc_heat_drought_compound_stress``.
    :type ds: xr.Dataset
    :return: DataArray named ``stresses`` with a ``stressor`` dimension.
    :rtype: xr.DataArray
    """
    stressors = (
        calc_heat_stress(ds.max_air_temp),
        calc_drought_stress(ds.evapotranspiration, ds.pot_evapotransp),
        calc_heat_drought_compound_stress(ds),
    )
    by_name = {}
    for stressor in stressors:
        by_name[stressor.name] = stressor
    return xr.Dataset(by_name).to_dataarray(dim="stressor", name="stresses")
def run_calc_stresses(
    years: Iterable[int],
    base_dir: str = str(DEFAULT_BASE_DIR),
):
    """
    Compute and persist yearly heat-drought compound stress datasets.

    For every year the stress store location is obtained from
    ``intermediate_stress``. Years whose store already exists are skipped with
    a warning. Otherwise water stress (``soil_depletion * 100 / TAW`` for
    ``layer="top"``), ``max_air_temp`` and ``Kc_factor`` are merged, the
    compound stress is evaluated block-wise, and the result is written to the
    store as zarr.

    :param years: Years to process.
    :type years: Iterable[int]
    :param base_dir: Base directory passed on to the path helpers.
    :type base_dir: str
    :return: None
    """
    taw = xr.open_dataarray(str(input_taw(base_dir=base_dir)), decode_coords="all")
    for year in years:
        stress_store = intermediate_stress(year, base_dir=base_dir)
        yearly_intermediate = intermediate_year(year, base_dir=base_dir)
        yearly_reference = reference_year(year, base_dir=base_dir)
        if stress_store.exists():
            print(f"! WARNING: {stress_store} already exists. Skipping.")
            continue
        print("Calculating stress index for year", year)
        ds = xr.open_zarr(str(yearly_intermediate), decode_coords="all")

        # ``DataArray.chunks`` always exists and is ``None`` for data that is
        # not dask-backed, so the former ``hasattr`` test never fired and TAW
        # was never re-chunked.
        if taw.chunks is None:
            reference_chunks = ds.soil_depletion.chunksizes
            taw = taw.chunk(
                {
                    dim: reference_chunks[dim]
                    for dim in taw.dims
                    # Only reuse chunking for dimensions that line up.
                    if dim in reference_chunks and taw.sizes[dim] == ds.sizes[dim]
                }
            )

        waterstress = (
            (ds.soil_depletion * 100 / taw).sel(layer="top").rename("waterstress")
        )
        max_air_temp = xr.open_zarr(
            str(yearly_reference), decode_coords="all"
        ).max_air_temp.astype("f4")
        data_collection = xr.merge(
            [waterstress, max_air_temp, ds.Kc_factor.astype("f4")]
        )
        csi = set_heat_drought_compound_stress_meta(
            data_collection.map_blocks(
                calc_heat_drought_compound_stress,
                # The result has the layout of Kc_factor.
                template=data_collection.Kc_factor.astype("f4"),
            )
        )
        csi.drop_encoding().to_zarr(str(stress_store), mode="a-")
def set_yield_meta(da: xr.DataArray) -> xr.DataArray:
    """
    Name and annotate the yield depression array.

    The array is renamed to ``yield_depression`` and receives the ``units``,
    ``long_name`` and ``description`` attributes defined below.

    :param da: DataArray to annotate.
    :type da: xr.DataArray
    :return: Renamed DataArray carrying the additional attributes.
    :rtype: xr.DataArray
    """
    attributes = {
        "units": "%",
        "long_name": "Yield depression",
        "description": (
            "The yield depression given a certain combined stress which is "
            "crop specific and bases on water availability and heat above defined "
            "thresholds."
        ),
    }
    renamed = da.rename("yield_depression")
    return renamed.assign_attrs(attributes)
def calc_yield(
    stress_data: xr.DataArray,
    yield_max: xr.DataArray,
    yield_intercept: xr.DataArray,
    params: xr.DataArray,
    temporal_mask: xr.DataArray = xr.DataArray(True),
    grass_cut_numbers: xr.DataArray = xr.DataArray(0),
) -> xr.DataArray:
    """
    Estimate yield depression from one or multiple stressors.

    The model is linear: ``stress_data * params`` is summed over ``stressor``
    and added to ``yield_intercept``; the result is related to ``yield_max``
    and expressed as ``100 * (1 - expectation / yield_max)``, clipped to
    ``[0, 100]``.

    When ``stress_data`` has a ``time`` dimension it is first masked with
    ``temporal_mask``, cumulatively summed within yearly resampling bins
    (``"YE"``), reindexed to the original time coordinate and, for chunked
    input, re-chunked along ``time`` like the input.

    Unless ``grass_cut_numbers`` is ``None`` (note that the default is not
    ``None``), ``params`` is broadcast against it and the ``grassM`` entries
    are filled from ``grass2`` and then ``grass1`` wherever they are falsy.
    ``grass1`` and ``grass2`` are dropped and ``grassM`` is relabelled
    ``grassland``.

    :param stress_data: Stress input with a ``stressor`` dimension.
    :type stress_data: xr.DataArray
    :param yield_max: Maximum attainable yield.
    :type yield_max: xr.DataArray
    :param yield_intercept: Intercept of the linear model.
    :type yield_intercept: xr.DataArray
    :param params: Stress-response coefficients; must contain the crops
        ``grass1``, ``grass2`` and ``grassM`` whenever grass merging runs.
    :type params: xr.DataArray
    :param temporal_mask: Mask selecting the time steps that contribute to the
        cumulative stress.
    :type temporal_mask: xr.DataArray
    :param grass_cut_numbers: Array ``params`` is broadcast against before the
        grass classes are merged; ``None`` disables the merge.
    :type grass_cut_numbers: xr.DataArray
    :return: Yield depression in percent, named ``yield_depression``.
    :rtype: xr.DataArray
    """
    if "time" in stress_data.dims:
        accumulated = (
            stress_data.where(temporal_mask).resample(time="YE").cumsum("time")
        )
        accumulated = accumulated.reindex(time=stress_data.time)
        if stress_data.chunks is not None:
            # Restore the input's chunking along time.
            time_axis = stress_data.dims.index("time")
            accumulated = accumulated.chunk(time=stress_data.chunks[time_axis])
        stress_data = accumulated

    if grass_cut_numbers is not None:
        # Work on a copy so the caller's ``params`` is left untouched by the
        # in-place assignment below.
        merged = params.broadcast_like(grass_cut_numbers).copy()
        grass_m = merged.sel(crop="grassM")
        grass_2 = merged.sel(crop="grass2")
        fallback = grass_2.where(grass_2, merged.sel(crop="grass1"))
        merged.loc[{"crop": "grassM"}] = grass_m.where(grass_m, fallback)
        merged = merged.drop_sel(crop=["grass1", "grass2"])
        relabelled = [
            "grassland" if label == "grassM" else label
            for label in merged.crop.values
        ]
        params = merged.assign_coords(crop=relabelled)

    # Presumably negative contributions (coefficients expected to be negative).
    yield_reduction = stress_data * params
    yield_expectation = yield_intercept + yield_reduction.sum("stressor")
    relative_loss = 1 - yield_expectation / yield_max
    yield_depression = (100 * relative_loss).clip(0, 100)

    # Output dimension order follows the stress input without ``stressor``.
    target_dims = stress_data.isel(stressor=0).dims
    return set_yield_meta(yield_depression.transpose(*target_dims))
def _open_dataarray(path: str) -> xr.DataArray:
    """Open ``path`` and return the single DataArray it contains.

    Directories and ``.zarr`` paths are opened as zarr stores. Other paths are
    tried with ``xr.open_dataarray`` first and, on ``ValueError``, with
    ``xr.open_dataset``. A dataset is accepted only when it holds exactly one
    data variable.

    :raises ValueError: If no single data variable can be inferred.
    """
    location = Path(path)
    target = str(location)
    if location.is_dir() or location.suffix == ".zarr":
        opened = xr.open_zarr(target, decode_coords="all")
    else:
        try:
            return xr.open_dataarray(target, decode_coords="all")
        except ValueError:
            opened = xr.open_dataset(target, decode_coords="all")

    if isinstance(opened, xr.DataArray):
        return opened
    variable_names = list(opened.data_vars)
    if len(variable_names) == 1:
        return opened[variable_names[0]]
    raise ValueError(
        f"Could not infer a single DataArray from '{location}'. "
        "Provide a path containing exactly one data variable."
    )


def _validate_yield_paths(
    *,
    mode: str,
    yield_max_path: str | None,
    yield_intercept_path: str | None,
    yield_params_path: str | None,
):
    """Ensure the three yield parameter paths are set for yield modes.

    Nothing is checked unless ``mode`` is ``"yield"`` or ``"both"``.

    :raises ValueError: Listing the command-line flags of all empty paths.
    """
    if mode not in {"yield", "both"}:
        return
    flag_values = {
        "--yield-max": yield_max_path,
        "--yield-intercept": yield_intercept_path,
        "--yield-params": yield_params_path,
    }
    missing = [flag for flag, value in flag_values.items() if not value]
    if missing:
        raise ValueError(
            "Yield mode requires explicit parameter paths. Missing flags: "
            + ", ".join(missing)
        )


def run_yield_depression(
    years: Iterable[int],
    *,
    yield_max_path: str,
    yield_intercept_path: str,
    yield_params_path: str,
    temporal_mask_path: str | None = None,
    grass_cut_numbers_path: str | None = None,
    base_dir: str = str(DEFAULT_BASE_DIR),
):
    """Compute and persist yearly yield_depression from stored stress inputs.

    Parameter arrays are opened once. For each year the stored
    ``heat_drought_compound_stress`` is read, given a ``stressor`` dimension
    if it lacks one, passed to ``calc_yield`` and written as zarr to the
    location returned by ``output_year``. Years whose output already exists
    are skipped with a warning.
    """
    yield_max = _open_dataarray(yield_max_path)
    yield_intercept = _open_dataarray(yield_intercept_path)
    params = _open_dataarray(yield_params_path)

    if temporal_mask_path is None:
        # A scalar ``True`` keeps every time step.
        temporal_mask = xr.DataArray(True)
    else:
        temporal_mask = _open_dataarray(temporal_mask_path)

    grass_cut_numbers = None
    if grass_cut_numbers_path is not None:
        grass_cut_numbers = _open_dataarray(grass_cut_numbers_path)

    for year in years:
        yearly_output = output_year(year, base_dir=base_dir)
        yearly_stress = intermediate_stress(year, base_dir=base_dir)
        if yearly_output.exists():
            print(f"! WARNING: {yearly_output} already exists. Skipping.")
            continue
        print("Calculating yield depression for year", year)

        stress_store = xr.open_zarr(str(yearly_stress), decode_coords="all")
        csi = stress_store.heat_drought_compound_stress
        if "stressor" in csi.dims:
            stress_data = csi
        else:
            stress_data = csi.expand_dims(stressor=["heat_drought_compound_stress"])

        yield_depression = calc_yield(
            stress_data=stress_data,
            yield_max=yield_max,
            yield_intercept=yield_intercept,
            params=params,
            temporal_mask=temporal_mask,
            grass_cut_numbers=grass_cut_numbers,
        )
        result = yield_depression.to_dataset(name=yield_depression.name)
        result.drop_encoding().to_zarr(str(yearly_output), mode="a-")
def run_calc_yield(
    years: Iterable[int],
    mode: str = "auto",
    workers: int = 4,
    mem_per_worker: str = "1Gb",
    base_dir: str = str(DEFAULT_BASE_DIR),
    yield_max_path: str | None = None,
    yield_intercept_path: str | None = None,
    yield_params_path: str | None = None,
    temporal_mask_path: str | None = None,
    grass_cut_numbers_path: str | None = None,
):
    """
    Compute stress indices and/or yield depression for the given years.

    The work runs on a local Dask cluster that is started here and shut down
    again afterwards, whether or not the computation succeeds.

    :param years: Years to process; they are handled in ascending order.
    :type years: Iterable[int]
    :param mode: ``"stress"``, ``"yield"``, ``"both"`` or ``"auto"``
        (case-insensitive). ``"auto"`` selects ``"yield"`` when a stress store
        exists for every year and ``"both"`` otherwise.
    :type mode: str
    :param workers: Number of Dask workers.
    :type workers: int
    :param mem_per_worker: Memory limit per Dask worker.
    :type mem_per_worker: str
    :param base_dir: Base directory passed on to the path helpers.
    :type base_dir: str
    :param yield_max_path: Path to maximum yield data; required for yield modes.
    :param yield_intercept_path: Path to intercept data; required for yield
        modes.
    :param yield_params_path: Path to the stress-response parameters; required
        for yield modes.
    :param temporal_mask_path: Optional path to a temporal mask.
    :param grass_cut_numbers_path: Optional path to grass cut numbers.
    :raises ValueError: If ``mode`` is unknown, or a yield mode is requested
        without all three required paths.
    :return: None
    """
    years = sorted(years)
    mode = mode.lower()

    valid_modes = ("auto", "stress", "yield", "both")
    if mode not in valid_modes:
        # Previously an unknown mode computed nothing yet reported success.
        raise ValueError(
            f"Unknown mode '{mode}'. Choose one of: " + ", ".join(valid_modes)
        )

    if mode == "auto":
        missing_years = [
            year
            for year in years
            if not intermediate_stress(year, base_dir=base_dir).exists()
        ]
        if not missing_years:
            print(
                "Stress index is present, assuming you want to have the yield "
                "depression computed."
            )
            mode = "yield"
        else:
            print(
                "Stress index is missing for year(s):",
                ", ".join([str(year) for year in missing_years])
                + ". Computing these first before estimating yield depression.",
            )
            mode = "both"

    # Fail on missing paths before any Dask worker is started.
    _validate_yield_paths(
        mode=mode,
        yield_max_path=yield_max_path,
        yield_intercept_path=yield_intercept_path,
        yield_params_path=yield_params_path,
    )

    from dask.distributed import LocalCluster, Client

    print(f"Starting dask ({workers} CPUs, each {mem_per_worker} RAM)")
    cluster = LocalCluster(
        n_workers=workers, memory_limit=mem_per_worker, death_timeout=30
    )
    try:
        client = Client(cluster)
        try:
            print("... access the dashboard at", client.dashboard_link)
            if mode in {"stress", "both"}:
                run_calc_stresses(years=years, base_dir=base_dir)
            if mode in {"yield", "both"}:
                run_yield_depression(
                    years=years,
                    yield_max_path=str(yield_max_path),
                    yield_intercept_path=str(yield_intercept_path),
                    yield_params_path=str(yield_params_path),
                    temporal_mask_path=temporal_mask_path,
                    grass_cut_numbers_path=grass_cut_numbers_path,
                    base_dir=base_dir,
                )
        except FileNotFoundError as err:
            if str(err).startswith("Unable to find group"):
                print(
                    "\n! ERROR: data missing. Verify that the necessary data are "
                    "available.\n"
                )
            raise
        finally:
            client.close()
            print("Closed dask client\n")
    finally:
        # A client does not shut down a cluster that was handed to it, so the
        # workers must be stopped explicitly.
        cluster.close()
    print(f"Successfully computed {mode} related variables!\n")
def main_heat_drought_compound_stress(
    years: Iterable[int],
    base_dir: str = str(DEFAULT_BASE_DIR),
):
    """Legacy alias for :func:`run_calc_stresses`.

    Emits the legacy-API warning and then forwards ``years`` and ``base_dir``
    unchanged.
    """
    legacy_name = "aris_lite.yield_expectation.main_heat_drought_compound_stress"
    replacement = "aris_lite.yield_expectation.run_calc_stresses"
    warn_legacy_python_api(legacy_name, replacement)

    forwarded = {"years": years, "base_dir": base_dir}
    run_calc_stresses(**forwarded)
def main_yield(
    years: Iterable[int],
    *,
    yield_max_path: str | None = None,
    yield_intercept_path: str | None = None,
    yield_params_path: str | None = None,
    temporal_mask_path: str | None = None,
    grass_cut_numbers_path: str | None = None,
    base_dir: str = str(DEFAULT_BASE_DIR),
):
    """Legacy entry point that delegates to :func:`run_yield_depression`.

    Emits the legacy-API warning (which names ``run_calc_yield`` as the
    replacement), checks that the three required parameter paths are set, and
    forwards everything to ``run_yield_depression``.
    """
    warn_legacy_python_api(
        "aris_lite.yield_expectation.main_yield",
        "aris_lite.yield_expectation.run_calc_yield",
    )

    required_paths = {
        "yield_max_path": yield_max_path,
        "yield_intercept_path": yield_intercept_path,
        "yield_params_path": yield_params_path,
    }
    _validate_yield_paths(mode="yield", **required_paths)

    # Validation passed, so the required paths are safe to stringify.
    stringified = {name: str(value) for name, value in required_paths.items()}
    run_yield_depression(
        years=years,
        temporal_mask_path=temporal_mask_path,
        grass_cut_numbers_path=grass_cut_numbers_path,
        base_dir=base_dir,
        **stringified,
    )
def main_cli(argv: list[str] | None = None) -> int:
    """Legacy CLI alias for `aris-calc-yield`; use `aris calc yield`.

    Emits the legacy-API warning, then hands ``argv`` to the legacy wrapper
    with its own warning disabled and returns the wrapper's result.
    """
    legacy_name = "aris_lite.yield_expectation.main_cli"
    replacement = "aris_lite.cli.main:main_cli"
    warn_legacy_python_api(legacy_name, replacement)

    # Imported lazily, after the warning has been issued.
    from aris_lite.cli.legacy_wrappers import legacy_yield_cli

    exit_code = legacy_yield_cli(argv=argv, emit_warning=False)
    return exit_code
# Script entry point: run the legacy CLI and exit with the value it returns.
if __name__ == "__main__": raise SystemExit(main_cli())