
👋🌎 🌲🔥 This is the raster module docstring

List all available GDAL drivers: $ python -c "from osgeo import gdal;print('\n'.join(sorted([gdal.GetDriver(i).GetDescription() for i in range(gdal.GetDriverCount())])))"

  2"""👋🌎 🌲🔥
  3This is the raster module docstring
  5List all gdal available drivers:
  6$ python -c "from osgeo import gdal;print('\n'.join(sorted([gdal.GetDriver(i).GetDescription() for i in range(gdal.GetDriverCount())])))"
  8__author__ = "Fernando Badilla"
  9__revision__ = "$Format:%H$"
 11import logging
 12from pathlib import Path
 13from typing import Any, Dict, List, Optional, Tuple, Union
 15import numpy as np
 16from osgeo import gdal, ogr
 17from qgis.core import QgsRasterLayer
 19from .utils import fprint, qgis2numpy_dtype
 21logger = logging.getLogger(__name__)
 24def id2xy(idx: int, w: int, h: int) -> tuple[int, int]:
 25    """Transform a pixel or cell index, into x,y coordinates.
 26    In GIS, the origin is at the top-left corner, read from left to right, top to bottom.  
 27    If your're used to matplotlib, the y-axis is inverted.  
 28    Also as numpy array, the index of the pixel is [y, x].
 30    Args:
 31        param idx: index of the pixel or cell (0,..,w*h-1)  
 32        param w: width of the image or grid  
 33        param h: height of the image or grid (not really used!)
 35    Returns:
 36        tuple: (x, y) coordinates of the pixel or cell  
 37    """  # fmt: skip
 38    return idx % w, idx // w
 41def xy2id(x: int, y: int, w: int) -> int:
 42    """Transform a x,y coordinates into a pixel or cell index.
 43    In GIS, the origin is at the top-left corner, read from left to right, top to bottom.  
 44    If your're used to matplotlib, the y-axis is inverted.  
 45    Also as numpy array, the index of the pixel is [y, x].
 47    Args:
 48        param x: width or horizontal coordinate of the pixel or cell  
 49        param y: height or vertical coordinate of the pixel or cell  
 50        param w: width of the image or grid  
 52    Returns:
 53        int: index of the pixel or cell (0,..,w\*h-1)
 54    """  # fmt: skip
 55    return y * w + x
 58def read_raster_band(filename: str, band: int = 1) -> tuple[np.ndarray, int, int]:
 59    """Read a raster file and return the data as a numpy array, along width and height.
 61    Args:
 62        param filename: name of the raster file  
 63        param band: band number to read (default 1)
 65    Returns:
 66        tuple: (data, width, height)
 68    Raises:
 69        FileNotFoundError: if the file is not found
 70    """  # fmt: skip
 71    dataset = gdal.Open(filename, gdal.GA_ReadOnly)
 72    if dataset is None:
 73        raise FileNotFoundError(filename)
 74    return dataset.GetRasterBand(band).ReadAsArray(), dataset.RasterXSize, dataset.RasterYSize
 77def read_raster(
 78    filename: str, band: int = 1, data: bool = True, info: bool = True
 79) -> tuple[Union[np.ndarray, None], Union[dict, None]]:
 80    """Read a raster file and return the data as a numpy array.
 81    Along raster info: transform, projection, raster count, raster width, raster height.
 83    Args:
 84        param filename: name of the raster file
 85        param band: band number to read (default 1)
 86        param data: if True, return the data as a numpy array (default True)
 87        param info: if True, return the raster info (default True)
 89    Return tuple: (data, info)
 90        data: numpy 2d array with the raster data
 91        info: dictionary with keys:
 92            - Transform: geotransform parameters
 93            - Projection: projection string
 94            - RasterCount: number of bands
 95            - RasterXSize: width of the raster
 96            - RasterYSize: height of the raster
 97            - DataType: data type of the raster
 98            - NoDataValue: no data value of the raster
 99            - Minimum: minimum value of the raster
100            - Maximum: maximum value of the raster
102    Raises:
103        FileNotFoundError: if the file is not found
104    """  # fmt: skip
105    dataset = gdal.Open(filename, gdal.GA_ReadOnly)
106    if dataset is None:
107        raise FileNotFoundError(filename)
108    raster_band = dataset.GetRasterBand(band)
109    data_output = raster_band.ReadAsArray() if data else None
111    if info:
112        rmin = raster_band.GetMinimum()
113        rmax = raster_band.GetMaximum()
114        if not rmin or not rmax:
115            (rmin, rmax) = raster_band.ComputeRasterMinMax(True)
117    info_output = (
118        {
119            "Transform": dataset.GetGeoTransform(),
120            "Projection": dataset.GetProjection(),
121            "RasterCount": dataset.RasterCount,
122            "RasterXSize": dataset.RasterXSize,
123            "RasterYSize": dataset.RasterYSize,
124            "DataType": gdal.GetDataTypeName(raster_band.DataType),
125            "NoDataValue": raster_band.GetNoDataValue(),
126            "Minimum": rmin,
127            "Maximum": rmax,
128        }
129        if info
130        else None
131    )
132    return data_output, info_output
def get_geotransform(raster_filename: str) -> tuple[float, float, float, float, float, float]:
    """Get the GDAL geotransform of a raster file.

    Args:
        raster_filename (str): path of the raster file.

    Returns:
        tuple: geotransform ``GT`` where
            GT[0] x-coordinate of the upper-left corner of the upper-left pixel.
            GT[1] w-e pixel resolution / pixel width.
            GT[2] row rotation (typically zero).
            GT[3] y-coordinate of the upper-left corner of the upper-left pixel.
            GT[4] column rotation (typically zero).
            GT[5] n-s pixel resolution / pixel height (negative value for a north-up image).

    Raises:
        Exception: if GDAL cannot open ``raster_filename``.
    """  # fmt: skip
    opened = gdal.Open(raster_filename, gdal.GA_ReadOnly)
    if opened is not None:
        return opened.GetGeoTransform()
    raise Exception(f"Data set is None, could not open {raster_filename}")
def transform_coords_to_georef(x_pixel: int, y_line: int, GT: tuple) -> tuple[float, float]:
    """Map pixel/line coordinates to georeferenced coordinates with an affine geotransform.

    Args:
        x_pixel (int): pixel (column) coordinate.
        y_line (int): line (row) coordinate.
        GT (tuple): geotransform, see ``get_geotransform(filename)``.

    Returns:
        tuple: ``(x_geo, y_geo)``, computed as
            x_geo = GT[0] + x_pixel * GT[1] + y_line * GT[2]
            y_geo = GT[3] + x_pixel * GT[4] + y_line * GT[5]
    """  # fmt: skip
    origin_x, pixel_w, row_rot = GT[0], GT[1], GT[2]
    origin_y, col_rot, pixel_h = GT[3], GT[4], GT[5]
    return (
        origin_x + x_pixel * pixel_w + y_line * row_rot,
        origin_y + x_pixel * col_rot + y_line * pixel_h,
    )
def transform_georef_to_coords(
    x_geo: float, y_geo: float, GT: tuple, *, tol: float = 1e-9
) -> tuple[int, int]:
    """Inverse of transform_coords_to_georef: map georeferenced coordinates to pixel/line indices.

    Derivation of the inverse affine transform:
        import sympy
        a, b, c, d, e, f, g, i, j, x, y = sympy.symbols('a, b, c, d, e, f, g, i, j, x, y', real=True)
        sympy.linsolve([a+i*b+j*c - x,d+i*e+j*f-y],(i,j))
        {((-a*f + c*d - c*y + f*x)/(b*f - c*e), (a*e - b*d + b*y - e*x)/(b*f - c*e))}

    Args:
        x_geo (float): x georeferenced coordinate.
        y_geo (float): y georeferenced coordinate.
        GT (tuple): geotransform, see ``get_geotransform(filename)``.
        tol (float, keyword-only): fractional pixel coordinates within ``tol`` of an
            integer are snapped to it, absorbing float round-off. Defaults to 1e-9.

    Returns:
        tuple: ``(x_pixel, y_line)`` of the pixel containing the point, rounded down.
        Points left of or above the raster give negative indices.

    Raises:
        ZeroDivisionError: if the geotransform is not invertible (``b*f - c*e == 0``).
    """
    a, b, c, d, e, f = GT
    det = b * f - c * e
    if det == 0:
        raise ZeroDivisionError("geotransform is not invertible: GT[1]*GT[5] - GT[2]*GT[4] == 0")
    i = (-a * f + c * d - c * y_geo + f * x_geo) / det
    j = (a * e - b * d + b * y_geo - e * x_geo) / det

    indices = []
    for value in (i, j):
        nearest = round(value)
        if abs(value - nearest) <= tol:
            # round-off from the forward transform (e.g. 2.9999999999) must not fall to the previous pixel
            indices.append(int(nearest))
        else:
            # floor, not int(): int() truncates toward zero and would map -0.5 into pixel 0
            indices.append(int(np.floor(value)))
    return indices[0], indices[1]
def get_rlayer_info(layer: QgsRasterLayer):
    """Get raster layer info: width, height, extent, crs, cell sizes, nodata list, band count, source.

    Args:
        layer (QgsRasterLayer): A raster layer

    Returns:
        dict: raster layer info with keys
            width, height: ``layer.width()`` / ``layer.height()``
            extent: ``layer.extent()``
            crs: ``layer.crs()``
            cellsize_x, cellsize_y: ``rasterUnitsPerPixelX()`` / ``rasterUnitsPerPixelY()``
            nodata: list with one entry per band (band 1 first); the provider's source
                nodata value, or None when the source declares none for that band
            bands: ``layer.bandCount()``
            file: ``layer.publicSource()``
    """
    provider = layer.dataProvider()
    band_count = layer.bandCount()
    # provider band numbers start at 1
    ndv = [
        provider.sourceNoDataValue(band) if provider.sourceHasNoDataValue(band) else None
        for band in range(1, band_count + 1)
    ]
    return {
        "width": layer.width(),
        "height": layer.height(),
        "extent": layer.extent(),
        # NOTE(review): the value was missing ("crs":,) which is a syntax error; restored
        # as the layer's CRS object — confirm callers don't expect an authid string.
        "crs": layer.crs(),
        "cellsize_x": layer.rasterUnitsPerPixelX(),
        "cellsize_y": layer.rasterUnitsPerPixelY(),
        "nodata": ndv,
        "bands": band_count,
        "file": layer.publicSource(),
    }
def get_rlayer_data(layer: QgsRasterLayer):
    """Get raster layer data (EVERY BAND) as a numpy array.

    The user should check the shape of the data to determine if it is a single band or
    multiband raster: ``len(data.shape) == 2`` for single band (height, width),
    ``len(data.shape) == 3`` for multiband (bands, height, width).
    Nodata values are not returned; use ``get_rlayer_info(layer)["nodata"]``.

    Args:
        layer (QgsRasterLayer): A raster layer

    Returns:
        data (np.ndarray): Raster data. Each band is decoded with the numpy dtype that
        ``qgis2numpy_dtype`` gives for the provider's band data type. For multiband layers
        ``np.array`` stacks the bands, so bands with different dtypes are promoted to a
        common dtype. The single-band result is a read-only view over the block buffer
        (``np.frombuffer``); copy it before modifying.

    TODO: make a band list as input
    """
    provider = layer.dataProvider()
    extent = layer.extent()
    width, height = layer.width(), layer.height()

    def read_band(band: int) -> np.ndarray:
        # Read one whole band (provider band numbers start at 1) as a (height, width) array
        block = provider.block(band, extent, width, height)
        np_dtype = qgis2numpy_dtype(provider.dataType(band))
        # NOTE(review): the buffer argument was missing (frombuffer(, ...)) which is a
        # syntax error; restored as the block's raw bytes — confirm against upstream.
        return np.frombuffer(block.data(), dtype=np_dtype).reshape(height, width)

    if layer.bandCount() == 1:
        return read_band(1)
    return np.array([read_band(band) for band in range(1, layer.bandCount() + 1)])
def get_cell_sizeV2(filename: str, band: int = 1) -> tuple[float, float]:
    """Get the (x, y) cell size of a raster file from its geotransform.

    Prefer ``get_cell_size``; this variant always returns a tuple.
    Previously this returned RasterXSize/RasterYSize, i.e. the raster size in pixels,
    not the cell size.

    Args:
        filename (str): path of the raster file.
        band (int): band passed through to ``read_raster`` (default 1).

    Returns:
        tuple: ``(GT[1], -GT[5])``, i.e. pixel width and pixel height; both are positive
        for a north-up raster.

    Raises:
        FileNotFoundError: if the file cannot be opened (raised by ``read_raster``).
    """
    # TODO: deprecate this function
    _, info = read_raster(filename, band=band, data=False, info=True)
    geotransform = info["Transform"]
    return geotransform[1], -geotransform[5]
def get_cell_size(raster: Union[gdal.Dataset, str, Path]) -> Union[float, tuple[float, float]]:
    """Get the cell size(s) of a raster from its geotransform.

    Args:
        raster (gdal.Dataset | str | Path): an open GDAL dataset, or a path to a raster file.

    Returns:
        float | tuple[float, float]: ``GT[1]`` when it equals ``-GT[5]`` (square cells),
        otherwise the tuple ``(GT[1], -GT[5])``. Rotation terms GT[2]/GT[4] are ignored.

    Raises:
        FileNotFoundError: if ``raster`` is a path that GDAL cannot open.
        ValueError: if ``raster`` is neither a path nor a ``gdal.Dataset``.
    """  # fmt: skip
    if isinstance(raster, (str, Path)):
        ds = gdal.Open(str(raster), gdal.GA_ReadOnly)
        if ds is None:
            # gdal.Open returns None on failure; fail clearly instead of an AttributeError below
            raise FileNotFoundError(str(raster))
    elif isinstance(raster, gdal.Dataset):
        ds = raster
    else:
        raise ValueError("Invalid input type for raster")

    geotransform = ds.GetGeoTransform()
    size_x, size_y = geotransform[1], -geotransform[5]
    if size_x != size_y:
        return size_x, size_y
    return size_x
def mask_raster(
    raster_ds: gdal.Dataset, band: Union[int, gdal.Band], polygons: list[ogr.Geometry]
) -> np.ndarray:
    """Mask a raster with polygons using GDAL: pixels not covered by any polygon become NaN.

    Args:
        raster_ds (gdal.Dataset): GDAL dataset of the raster.
        band (int | gdal.Band): band number of ``raster_ds`` (as for ``GetRasterBand``),
            or an already fetched band object.
        polygons (list[ogr.Geometry]): polygons to keep, presumably in the same georeferenced
            coordinates/CRS as ``raster_ds`` (no reprojection is done — confirm with callers).

    Returns:
        np.ndarray: band data where a polygon covers the pixel, NaN elsewhere
        (``np.where`` with NaN promotes integer data to float).
    """  # fmt: skip
    # Accept a Band object (stack_rasters passes one) as well as a band number
    raster_band = band if hasattr(band, "ReadAsArray") else raster_ds.GetRasterBand(int(band))

    # Burn the polygons on the raster's own grid so georeferenced geometries land on the right pixels
    mask_array = rasterize_polygons(
        polygons,
        raster_ds.RasterXSize,
        raster_ds.RasterYSize,
        geo_transform=raster_ds.GetGeoTransform(),
    )
    original_data = raster_band.ReadAsArray()
    return np.where(mask_array, original_data, np.nan)


def rasterize_polygons(
    polygons: list[ogr.Geometry], width: int, height: int, geo_transform: Optional[tuple] = None
) -> np.ndarray:
    """Rasterize polygons to a boolean array.

    Args:
        polygons (list[ogr.Geometry]): List of OGR geometries representing polygons for rasterization.
        width (int): Width of the raster.
        height (int): Height of the raster.
        geo_transform (tuple, optional): GDAL geotransform of the target grid. When None no
            geotransform is set on the target, so GDAL has no georeferencing to apply
            (presumably polygon coordinates are then taken as pixel/line — confirm).
            Defaults to None.

    Returns:
        mask_array (np.ndarray): (height, width) boolean array, True where a polygon was burned.
    """  # fmt: skip
    # gdal.RasterizeLayer burns into a GDAL dataset, not a numpy array: burn into an
    # in-memory single-band Byte raster and read it back.
    target_ds = gdal.GetDriverByName("MEM").Create("", width, height, 1, gdal.GDT_Byte)
    if geo_transform is not None:
        target_ds.SetGeoTransform(geo_transform)
    target_ds.GetRasterBand(1).Fill(0)

    # In-memory vector layer holding copies of the polygons
    mem_driver = ogr.GetDriverByName("Memory")
    mem_ds = mem_driver.CreateDataSource("memData")
    mem_layer = mem_ds.CreateLayer("memLayer", srs=None, geom_type=ogr.wkbPolygon)
    for geometry in polygons:
        mem_feature = ogr.Feature(mem_layer.GetLayerDefn())
        mem_feature.SetGeometry(geometry.Clone())
        mem_layer.CreateFeature(mem_feature)

    gdal.RasterizeLayer(target_ds, [1], mem_layer, burn_values=[1])
    mask_array = target_ds.GetRasterBand(1).ReadAsArray().astype(bool)

    # Release the in-memory datasets
    mem_ds = None
    target_ds = None
    return mask_array
def stack_rasters(
    file_list: list[Path], mask_polygon: Union[list[ogr.Geometry], None] = None
) -> tuple[np.ndarray, list[str]]:
    """Stack the first band of each raster file into a 3D NumPy array.

    Args:
        file_list (list[Path]): paths (Path or str) of the raster files.
        mask_polygon (list[ogr.Geometry], optional): polygons used to mask every raster
            (see ``mask_raster``); pixels outside them become NaN. Defaults to None.

    Returns:
        np.ndarray: array of shape (len(file_list), height, width).
        list: layer names (file stems), in the same order as the stacked arrays.

    Raises:
        ValueError: if a raster file cannot be opened.
        AssertionError: if ``file_list`` is empty or the rasters have different cell sizes.
    """  # fmt: skip
    # Explicit raises (not `assert`) so the checks survive `python -O`; the exception
    # type is kept as AssertionError for existing callers.
    if not file_list:
        raise AssertionError("file_list is empty: there are no rasters to stack")

    array_list = []
    cell_sizes = set()
    layer_names = []

    for raster_path in map(Path, file_list):
        layer_names.append(raster_path.stem)

        ds = gdal.Open(str(raster_path))
        if ds is None:
            raise ValueError(f"Failed to open raster file: {raster_path}")

        band = ds.GetRasterBand(1)
        if mask_polygon:
            array_list.append(mask_raster(ds, band, mask_polygon))
        else:
            array_list.append(band.ReadAsArray())
        cell_sizes.add(get_cell_size(ds))

    if len(cell_sizes) != 1:
        raise AssertionError(f"There are rasters with different cell sizes: {cell_sizes}")
    stacked_array = np.stack(array_list, axis=0)
    logger.debug("stacked array shape: %s", stacked_array.shape)
    return stacked_array, layer_names
def write_raster(
    data,
    outfile="output.tif",
    driver_name="GTiff",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, -1),
    nodata: int | float | None = None,
    feedback=None,
    logger=None,  # logger default ?
) -> bool:
    """Write a 2D numpy array as a single-band Float32 raster file.

    To spatially match another raster, get authid and geotransform using:
        from fire2a.raster import read_raster
        _, info = read_raster(filename, data=False, info=True)
        authid = info["Projection"]
        geotransform = info["Transform"]

    Args:
        data (np.ndarray): 2D array (height, width) to write; it is stored as Float32.
            A non-2D array makes ``H, W = data.shape`` raise ValueError.
        outfile (str, optional): output raster filename. Defaults to "output.tif".
        driver_name (str, optional): GDAL driver used when
            ``fire2a.processing_utils.get_output_raster_format`` cannot infer one from
            ``outfile``. Defaults to "GTiff".
        authid (str, optional): projection passed to ``SetProjection``. Defaults to "EPSG:3857".
        geotransform (tuple, optional): geotransform parameters. Defaults to (0, 1, 0, 0, 0, -1).
        nodata (int | float | None, optional): nodata value, set on the band only if it
            occurs in ``data`` (NaN is matched with ``np.isnan``). Defaults to None.
        feedback (optional): feedback object forwarded to fprint and
            get_output_raster_format. Defaults to None.
        logger (optional): logging.Logger forwarded to fprint. Defaults to None.

    Returns:
        bool: True if the raster was written; False if the GDAL driver is unknown, the
        dataset could not be created, or WriteArray failed.
    """
    try:
        from fire2a.processing_utils import get_output_raster_format

        driver_name = get_output_raster_format(outfile, feedback=feedback)
    except Exception as e:
        # Fall back to the caller's driver_name instead of silently forcing GTiff
        fprint(
            f"Couln't get output raster format: {e}, defaulting to {driver_name}",
            level="warning",
            feedback=feedback,
            logger=logger,
        )
    H, W = data.shape

    driver = gdal.GetDriverByName(driver_name)
    if driver is None:
        fprint(f"GDAL driver {driver_name} not found", level="warning", feedback=feedback, logger=logger)
        return False
    ds = driver.Create(outfile, W, H, 1, gdal.GDT_Float32)
    if ds is None:
        # e.g. unwritable path, or a driver without Create() support
        fprint(
            f"Could not create {outfile} with GDAL driver {driver_name}",
            level="warning",
            feedback=feedback,
            logger=logger,
        )
        return False
    ds.SetGeoTransform(geotransform)
    ds.SetProjection(authid)
    band = ds.GetRasterBand(1)
    if 0 != band.WriteArray(data):
        fprint("WriteArray failed", level="warning", feedback=feedback, logger=logger)
        return False

    # `is not None` so nodata=0 is honored; NaN never compares equal, so match it with isnan
    if nodata is not None:
        present = np.isnan(data).any() if np.isnan(nodata) else (data == nodata).any()
        if present:
            band.SetNoDataValue(nodata)
            # TBD : always returns 1?
            # if 0 != band.SetNoDataValue(nodata):
            #     fprint("Set NoData failed", level="warning", feedback=feedback, logger=logger)
            #     return False
    ds.FlushCache()
    ds = None
    return True
if __name__ == "__main__":
    # Demo: stack every *.asc raster found in the current working directory
    file_list = list(Path.cwd().glob("*.asc"))
    print(file_list)
    if file_list:
        # stack_rasters returns (array, layer_names), not just the array
        array, layer_names = stack_rasters(file_list)
        print(layer_names)
        print(array)
    else:
        print("No *.asc files found in", Path.cwd())
logger = <Logger fire2a.raster (WARNING)>
