⚠️ Our 0.1 release refactored several early-development functions for long-term stability. To update your code, see here. ⚠️
Future changes will come with deprecation warnings! 🙂

Source code for geoutils.raster.multiraster

"""Multiple rasters tools."""
from __future__ import annotations

import warnings
from typing import Any, Callable

import numpy as np
import rasterio as rio
import rasterio.warp
from tqdm import tqdm

import geoutils as gu
from geoutils._typing import NDArrayNum
from geoutils.misc import resampling_method_from_str
from geoutils.raster import Raster, RasterType, get_array_and_mask
from geoutils.raster.raster import _default_nodata


def load_multiple_rasters(
    raster_paths: list[str], crop: bool = True, ref_grid: int | None = None, **kwargs: Any
) -> list[RasterType]:
    """
    Load multiple rasters at once in a memory-efficient way.

    The rasters are first opened without loading their data. Then, optionally:

    - crop all rasters to their common intersection (default),
    - reproject all rasters to the grid of one raster used as reference (after the optional crop).

    If neither option is requested, the full rasters are simply loaded.

    :param raster_paths: List of paths to the rasters to be loaded.
    :param crop: If True, only load the rasters in the area where they all intersect.
    :param ref_grid: If set to an integer, the raster with that index is used as the reference and all other
        rasters are reprojected on its grid (after the optional crop). If None, no reprojection is done and the
        CRS of the first raster is used to compute the intersection.
    :param kwargs: Optional arguments passed to Raster.reproject, e.g. the resampling method.

    :raises ValueError: If `ref_grid` is neither None nor an integer.
    :raises IndexError: If `ref_grid` is out of range for `raster_paths`.

    :returns: A list of Raster instances, in the same order as `raster_paths`. If `crop` is True and the
        rasters do not intersect, a warning is emitted and the rasters are returned unloaded.
    """
    # If ref_grid is provided, need to reproject
    if isinstance(ref_grid, int):
        reproject = True
    # If no ref_grid is provided, a reference CRS is still needed for the intersection: use the first raster
    elif ref_grid is None:
        ref_grid = 0
        reproject = False
    else:
        raise ValueError("`ref_grid` must be None or an integer")

    # Nothing to load: without this guard, the default reference index 0 would fail on an empty list
    if not raster_paths and not reproject:
        return []

    # First open all rasters without loading their data (each file is opened only once)
    output_rst = [gu.Raster(path, load_data=False) for path in raster_paths]

    # Reference CRS in which the intersection is computed
    ref_crs = output_rst[ref_grid].crs

    # Bounds of every raster, expressed in the reference CRS
    bounds = [rst.get_bounds_projected(ref_crs) for rst in output_rst]

    # Second get the intersection of all raster bounds
    intersection = gu.projtools.merge_bounds(bounds, merging_algorithm="intersection")

    # Optionally, crop the rasters
    if crop:
        # Check that intersection is not void (changed to NaN instead of empty tuple end 2022)
        if intersection == () or all(np.isnan(i) for i in intersection):
            warnings.warn("Intersection is void, returning unloaded rasters.")
            return output_rst

        for rst in output_rst:
            # Calculate bounds in rst's CRS
            # rasterio's default for densify_pts is too low for very large images, set a default of 5000
            crop_bounds = rio.warp.transform_bounds(
                ref_crs, rst.crs, intersection[0], intersection[1], intersection[2], intersection[3], densify_pts=5000
            )
            # Ensure bounds align with the original ones, to avoid resampling at this stage
            crop_bounds = gu.projtools.align_bounds(rst.transform, crop_bounds)
            rst.crop(crop_bounds, mode="match_pixel", inplace=True)

    # Optionally, reproject all rasters to the reference grid
    if reproject:
        ref_rst = output_rst[ref_grid]

        # Set output bounds - intersection if crop is True, otherwise use that of ref_grid
        if crop:
            # Make sure new bounds align with reference's bounds (to avoid resampling ref)
            aligned = gu.projtools.align_bounds(ref_rst.transform, intersection)
            dst_bounds = {"left": aligned[0], "bottom": aligned[1], "right": aligned[2], "top": aligned[3]}
        else:
            dst_bounds = ref_rst.bounds

        # Reproject all rasters
        for index, rst in enumerate(output_rst):
            out_rst = rst.reproject(crs=ref_rst.crs, bounds=dst_bounds, res=ref_rst.res, silent=True, **kwargs)
            if not out_rst.is_loaded:
                out_rst.load()
            output_rst[index] = out_rst

    # If no crop or reproject option, simply load the rasters
    if not crop and not reproject:
        for rst in output_rst:
            rst.load()

    return output_rst


def stack_rasters(
    rasters: list[RasterType],
    reference: int | gu.Raster = 0,
    resampling_method: str | rio.enums.Resampling = "bilinear",
    use_ref_bounds: bool = False,
    diff: bool = False,
    progress: bool = True,
) -> gu.Raster:
    """
    Stack a list of rasters on their maximum extent into a multi-band raster.

    The input rasters can have any transform or CRS, and will be reprojected to the reference raster's CRS and
    resolution.
    The output multi-band raster has an extent that is the union of all raster extents, except if `use_ref_bounds`
    is used, and a number of bands equal to the number of input rasters.

    Use diff=True to return directly the difference to the reference raster.

    Note that every raster is loaded in memory once. Rasters that were not loaded on input are loaded only for
    their reprojection and released afterwards (except the reference raster, whose data is needed throughout).

    :param rasters: List of rasters to be stacked.
    :param reference: Index of reference raster in the list or separate reference raster.
        Defaults to the first raster in the list.
    :param resampling_method: Resampling method for reprojection.
    :param use_ref_bounds: If True, will use reference bounds, otherwise will use maximum bounds of all rasters.
    :param diff: If True, will return the difference to the reference raster.
    :param progress: If True, will display a progress bar. Default is True.

    :raises ValueError: If `reference` is neither an integer nor a geoutils.Raster.

    :returns: The merged raster with same CRS and resolution (and optionally bounds) as the reference.
    """
    # Check resampling method
    if isinstance(resampling_method, str):
        resampling_method = resampling_method_from_str(resampling_method)

    # Check raster has a single band
    if any(r.count > 1 for r in rasters):
        warnings.warn("Some input Rasters have multiple bands, only their first band will be used.")

    # Select reference raster
    if isinstance(reference, int):
        reference_raster = rasters[reference]
    elif isinstance(reference, gu.Raster):
        reference_raster = reference
    else:
        raise ValueError("reference should be either an integer or geoutils.Raster object")

    # Set output bounds
    if use_ref_bounds:
        dst_bounds = reference_raster.bounds
    else:
        dst_bounds = gu.projtools.merge_bounds(
            [raster.get_bounds_projected(out_crs=reference_raster.crs) for raster in rasters],
            resolution=reference_raster.res[0],
            return_rio_bbox=True,
        )

    # Make a data list and add all of the reprojected rasters into it.
    data: list[NDArrayNum] = []
    for raster in tqdm(rasters, disable=not progress):
        # Remember the initial state: a raster loaded only for this iteration is released at the end of it
        was_loaded = raster.is_loaded
        if not was_loaded:
            raster.load()

        # Compare against None explicitly so that a reference nodata of 0 is kept
        if reference_raster.nodata is not None:
            nodata = reference_raster.nodata
        else:
            nodata = _default_nodata(reference_raster.data.dtype)

        # Reproject to reference grid, with the requested resampling method
        reprojected_raster = raster.reproject(
            bounds=dst_bounds,
            res=reference_raster.res,
            crs=reference_raster.crs,
            dtype=reference_raster.data.dtype,
            nodata=reference_raster.nodata,
            resampling=resampling_method,
            silent=True,
        )
        reprojected_raster.set_nodata(nodata)

        # Optionally calculate difference
        if diff:
            # NOTE(review): this subtraction presumably requires the reference array to have the same shape as
            # the reprojected one, i.e. output bounds equal to the reference bounds - confirm.
            diff_to_ref = (reference_raster.data - reprojected_raster.data).squeeze()
            diff_to_ref, _ = get_array_and_mask(diff_to_ref)
            data.append(diff_to_ref)
        else:
            # Use only first band
            if reprojected_raster.count == 1:
                data.append(reprojected_raster.data[:])
            else:
                data.append(reprojected_raster.data[0, :])

        # Release the data of rasters that were loaded only for the reprojection.
        # The reference raster is kept, as its data and dtype are read at every iteration.
        if not was_loaded and raster is not reference_raster:
            raster._data = None

    # Convert to masked array
    stacked = np.ma.asarray(data)
    if reference_raster.nodata is not None:
        nodata = reference_raster.nodata
    else:
        nodata = _default_nodata(stacked.dtype)
    stacked[np.isnan(stacked)] = nodata

    # Save as gu.Raster - needed as some child classes may not accept multiple bands
    r = gu.Raster.from_array(
        data=stacked,
        transform=rio.transform.from_bounds(*dst_bounds, width=stacked[0].shape[1], height=stacked[0].shape[0]),
        crs=reference_raster.crs,
        nodata=nodata,
    )

    return r
def merge_rasters(
    rasters: list[RasterType],
    reference: int | Raster = 0,
    merge_algorithm: Callable | list[Callable] = np.nanmean,  # type: ignore
    resampling_method: str | rio.enums.Resampling = "bilinear",
    use_ref_bounds: bool = False,
    progress: bool = True,
) -> RasterType:
    """
    Spatially merge a list of rasters into one larger raster of their maximum extent.

    The input rasters can have any transform or CRS, and will be reprojected to the reference raster's CRS and
    resolution.
    The output merged raster has an extent that is the union of all raster extents, except if `use_ref_bounds`
    is used.

    The rasters are reprojected and stacked with `stack_rasters`, then reduced along the stacking axis.

    :param rasters: List of rasters to be merged.
    :param reference: Index of reference raster in the list or separate reference raster.
        Defaults to the first raster in the list.
    :param merge_algorithm: Reductor function (or list of functions) to merge the rasters with. Defaults to the mean.
        If several algorithms are provided, each result is returned as a separate band.
    :param resampling_method: Resampling method for reprojection.
    :param use_ref_bounds: If True, will use reference bounds, otherwise will use maximum bounds of all rasters.
    :param progress: If True, will display a progress bar. Default is True.

    :raises TypeError: If a merge algorithm cannot take a list as its first argument.
    :raises ValueError: If `reference` is neither an integer nor a geoutils.Raster.

    :returns: The merged raster with same CRS and resolution (and optionally bounds) as the reference.
    """
    # Make sure merge_algorithm is a list
    if not isinstance(merge_algorithm, (list, tuple)):
        merge_algorithm = [merge_algorithm]

    # Try to run the merge_algorithm with an arbitrary list. Raise an error if the algorithm is incompatible.
    for algo in merge_algorithm:
        try:
            algo([1, 2])
        except TypeError as exception:
            raise TypeError(
                f"merge_algorithm must be able to take a list as its first argument.\n\n{exception}"
            ) from exception

    # Select reference raster
    if isinstance(reference, int):
        reference_raster = rasters[reference]
    elif isinstance(reference, gu.Raster):
        reference_raster = reference
    else:
        raise ValueError("reference should be either an integer or geoutils.Raster object")

    # Reproject and stack all rasters
    raster_stack = stack_rasters(
        rasters,
        reference=reference,
        resampling_method=resampling_method,
        use_ref_bounds=use_ref_bounds,
        progress=progress,
    )

    # Messages raised when a callable does not accept `axis`: the first form comes from C-implemented
    # callables, the second from pure-Python functions.
    axis_not_supported = ("'axis' is an invalid keyword", "unexpected keyword argument 'axis'")

    # Try to use the keyword axis=0 for the merging algorithm (if it's a numpy ufunc).
    merged_bands = []
    for algo in merge_algorithm:
        try:
            merged_bands.append(algo(raster_stack.data, axis=0))
        # If that doesn't work, use the slower np.apply_along_axis approach.
        except TypeError as exception:
            # Any other TypeError is a genuine failure of the algorithm: let it propagate unchanged
            if not any(message in str(exception) for message in axis_not_supported):
                raise
            merged_bands.append(np.apply_along_axis(algo, axis=0, arr=raster_stack.data))

    # Convert to masked array, and set all NaNs to nodata
    merged_data = np.ma.asarray(merged_bands)
    if reference_raster.nodata is not None:
        nodata = reference_raster.nodata
    else:
        nodata = _default_nodata(merged_data.dtype)
    merged_data[np.isnan(merged_data)] = nodata

    # Save as a raster of the same class as the reference, with one band per merge algorithm
    merged_raster = reference_raster.from_array(
        data=np.reshape(merged_data, (len(merged_data),) + merged_data[0].shape),
        transform=rio.transform.from_bounds(
            *raster_stack.bounds, width=merged_data[0].shape[1], height=merged_data[0].shape[0]
        ),
        crs=reference_raster.crs,
        nodata=nodata,
    )

    return merged_raster