fire2a.agglomerative_clustering
Raster clustering
Usage
Overview
1. Choose your raster files
2. Configure nodata, scaling strategies and weights in the config.toml file
3. Choose a "distance threshold" (or a "number of clusters") for the Agglomerative clustering algorithm. Recommended:
   - Start with a distance threshold of 10.0; decrease it for more clusters, increase it for fewer
   - After calibrating the distance threshold, sieve small clusters (merge them into their biggest neighbor) with the --sieve integer_pixels_size option
Execution
# get command line help
python -m fire2a.agglomerative_clustering -h
python -m fire2a.agglomerative_clustering --help
# activate your qgis dev environment
source ~/pyqgisdev/bin/activate
# execute
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0
# Windows users should use QGIS's python
C:\PROGRA~1\QGIS33~1.3\bin\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
More info on how to run with QGIS's Python on Windows: https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers
Preparation
1. Choose your raster files
- Any GDAL compatible raster will be read
- Place them all in the same directory where the script will be executed
- "Quote them" if they contain any non-alphanumeric characters (outside [a-zA-Z0-9])
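For instance, the inputs can be sanity-checked from Python before running the module; a minimal sketch, assuming two hypothetical rasters (fuels.tif and elevation.tif) sitting in the current directory:

```python
from pathlib import Path

from fire2a.raster import read_raster
from fire2a.agglomerative_clustering import check_shapes

rasters = ["fuels.tif", "elevation.tif"]  # hypothetical filenames, any GDAL-readable format
data_list = []
for filename in rasters:
    data, info = read_raster(filename)  # 2D band array + metadata dict
    print(Path(filename).name, data.shape, info["NoDataValue"])
    data_list.append(data)

height, width = check_shapes(data_list)  # raises ValueError if shapes differ or are not 2D
```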
2. Preprocessing configuration
See the config.toml file for an example configuration of the preprocessing steps. The file is structured as follows:
["filename.tif"]
no_data_strategy = "most_frequent"
scaling_strategy = "onehot"
fill_value = 0
weight = 1
__scaling_strategy__
- can be "standard", "robust", "onehot"
- default is "robust"
- Standard: (x-mean)/stddev
- Robust: same but dropping the tails of the distribution
- OneHot: __for CATEGORICAL DATA__
__no_data_strategy__
- can be "mean", "median", "most_frequent", "constant"
- default is "mean"
- categorical data should use "most_frequent" or "constant"
- "constant" will use the value in __fill_value__ (see below)
- SimpleImputer
__fill_value__
- used when __no_data_strategy__ is "constant"
- default is 0
- SimpleImputer
__weight__
- default is 1
- used to give more importance to some features than others
- This is done after the nodata imputation and scaling steps, before clustering
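A hedged example of a multi-raster config.toml combining the keys above (filenames are hypothetical; omitted keys fall back to the defaults listed):

```toml
["elevation.tif"]           # continuous data
no_data_strategy = "mean"
scaling_strategy = "robust"
weight = 1

["slope.tif"]               # continuous data, down-weighted
no_data_strategy = "constant"
fill_value = 0
scaling_strategy = "standard"
weight = 0.5

["fuels.tif"]               # categorical codes
no_data_strategy = "most_frequent"
scaling_strategy = "onehot"
weight = 2
```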
3. Clustering configuration
- The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
  - -n or --n_clusters: The number of clusters to form, as well as the number of centroids to generate.
  - -d or --distance_threshold: The linkage distance threshold above which clusters will not be merged. When scaling, start with 10.0 and work downward (0.0 computes the full tree and can take a long time).
  - More parameters for clustering can be passed directly into the pipelie method as keyword arguments (see the example run below).
- A __Sieve filter__ can be applied to remove small clusters, using the GDAL sieve library:
  - --sieve: use the GDAL sieve filter to merge small clusters (given as a number of pixels) into the biggest neighbor
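A possible calibration run combining both steps (flag names taken from the argument parser documented below; output filename is hypothetical):

```bash
# cluster with a distance threshold of 10, then merge clusters smaller than 30 pixels
python -m fire2a.agglomerative_clustering -d 10.0 --sieve 30 -op clusters.gpkg -v
# or ask for a fixed number of clusters instead of a distance threshold
python -m fire2a.agglomerative_clustering -n 50 --sieve 30 -op clusters.gpkg -v
```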
4. Post-processing
Outputs can be:
- A raster file with the cluster labels and a polygon file with the cluster polygons
- A polygon file with the cluster polygons, with an attribute holding the number of pixels in each cluster
- A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS on Windows)
Or use the --script option to return the label_map and the pipeline objects for further processing in another python script:
from fire2a.agglomerative_clustering import main
label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
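A slightly longer sketch of script mode, assuming config.toml and its rasters are in the current directory; the fitted sklearn estimator lives inside the second pipeline:

```python
import numpy as np
from fire2a.agglomerative_clustering import main

# -s/--script returns the results, -nw skips writing raster/polygon outputs
label_map, pipe1, pipe2 = main(["-d", "10.0", "-s", "-nw"])

print(label_map.shape)                       # (height, width) array of cluster ids
print(len(np.unique(label_map)), "clusters")

model = pipe2.named_steps["agglomerative_clustering"].model
print(model.n_clusters_, model.n_leaves_)    # fitted AgglomerativeClustering attributes
```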
#!/usr/bin/env python3
# fmt: off
"""(module docstring: rendered above)"""
# fmt: on
# from IPython.terminal.embed import InteractiveShellEmbed
# InteractiveShellEmbed()()
import logging
import sys
from pathlib import Path

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import AgglomerativeClustering
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.neighbors import radius_neighbors_graph
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler

from fire2a.utils import fprint, read_toml

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    sys.exit(main(sys.argv))
def check_shapes(data_list):
    """Check if all data arrays have the same shape and are 2D.
    Returns the shape of the data arrays if they are all equal.
    """
    from functools import reduce

    def equal_or_error(x, y):
        """Check if x and y are equal, returns x if equal else raises a ValueError."""
        if x == y:
            return x
        else:
            raise ValueError("All data arrays must have the same shape")

    shape = reduce(equal_or_error, (data.shape for data in data_list))
    if len(shape) != 2:
        raise ValueError("All data arrays must be 2D")
    height, width = shape
    return height, width
Check if all data arrays have the same shape and are 2D. Returns the shape of the data arrays if they are all equal.
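A minimal illustration with in-memory arrays (not part of the module's own examples):

```python
import numpy as np
from fire2a.agglomerative_clustering import check_shapes

a = np.zeros((10, 20))
b = np.ones((10, 20))
print(check_shapes([a, b]))            # (10, 20)

try:
    check_shapes([a, np.zeros((5, 5))])
except ValueError as err:
    print(err)                         # All data arrays must have the same shape
```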
def get_map_neighbors(height, width, num_neighbors=8):
    """Get the neighbors of each cell in a 2D grid.
    n_jobs=-1 uses all available cores.
    """

    grid_points = np.indices((height, width)).reshape(2, -1).T

    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)

    # assert nb4.shape[0] == width * height
    # assert nb8.shape[1] == width * height
    # for n in range(width * height):
    #     _, neighbors = np.nonzero(nb4[n])
    #     assert 2 <= len(neighbors) <= 4, f"{n=} {neighbors=}"
    #     assert 3 <= len(neighbors) <= 8, f"{n=} {neighbors=}"
    return nb4, nb8
Get the neighbors of each cell in a 2D grid. n_jobs=-1 uses all available cores.
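A small sketch of the two connectivity graphs on a 3x4 grid (pixel indices are row-major, as built by np.indices):

```python
from fire2a.agglomerative_clustering import get_map_neighbors

nb4, nb8 = get_map_neighbors(3, 4)
print(nb4.shape)           # (12, 12) sparse adjacency, one row/column per pixel
print(nb4[0].nonzero())    # 4-neighbors of the top-left pixel: right (1) and below (4)
print(nb8[0].nonzero())    # 8-neighbors additionally include the diagonal (5)
```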
class NoDataImputer(BaseEstimator, TransformerMixin):
    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""

    def __init__(self, no_data_values, strategies, constants):
        self.no_data_values = no_data_values
        self.strategies = strategies
        self.constants = constants
        self.imputers = []
        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
            if no_data_value:
                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
            else:
                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]

    def fit(self, X, y=None):
        for i, imputer in enumerate(self.imputers):
            imputer.fit(X[:, [i]], y)
        return self

    def transform(self, X):
        for i, imputer in enumerate(self.imputers):
            X[:, [i]] = imputer.transform(X[:, [i]])
        self.output_data = X
        return X
A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column
    def __init__(self, no_data_values, strategies, constants):
        self.no_data_values = no_data_values
        self.strategies = strategies
        self.constants = constants
        self.imputers = []
        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
            if no_data_value:
                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
            else:
                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]
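A hedged usage sketch with synthetic data: treat -9999 as nodata, mean-fill the first column and constant-fill the second:

```python
import numpy as np
from fire2a.agglomerative_clustering import NoDataImputer

X = np.array([[1.0, 10.0],
              [-9999.0, 20.0],
              [3.0, -9999.0]])

imputer = NoDataImputer(no_data_values=[-9999.0, -9999.0],
                        strategies=["mean", "constant"],
                        constants=[None, 0])
print(imputer.fit(X).transform(X))  # [[1. 10.] [2. 20.] [3. 0.]] (transform also modifies X in place)
```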
class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
    """A custom transformer that rescales all features to a common range [0, 1]"""

    def __init__(self, weight_map):
        self.weight_map = weight_map

    def fit(self, X, y=None):
        # Determine the combined range of all scaled features
        self.min_val = [x.min() for x in X.T]
        self.max_val = [x.max() for x in X.T]
        return self

    def transform(self, X):
        # Rescale all features to match the common range
        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
            if ma - mi == 0:
                X.T[i] = x * self.weight_map[i]
            else:
                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
        return X
A custom transformer that rescales all features to a common range [0, 1]
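A quick numeric sketch with synthetic data: two columns rescaled to [0, 1], the second then weighted 3x:

```python
import numpy as np
from fire2a.agglomerative_clustering import RescaleAllToCommonRange

X = np.array([[0.0, 100.0],
              [5.0, 200.0],
              [10.0, 300.0]])

scaler = RescaleAllToCommonRange(weight_map=[1, 3])
print(scaler.fit(X).transform(X))
# column 0 -> [0, 0.5, 1], column 1 -> [0, 1.5, 3]
```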
class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
    def __init__(self, height, width, neighbors=4, **kwargs):
        self.height = height
        self.width = width
        self.neighbors = neighbors

        self.grid_points = np.indices((height, width)).reshape(2, -1).T
        if neighbors == 4:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
            )
        elif neighbors == 8:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
            )

        self.connectivity = connectivity
        self.kwargs = kwargs
        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)

    def fit(self, X, y=None):
        logger.debug("not sure why, but this method is never called alas needed")
        self.model.fit(X)
        return self

    def fit_predict(self, X, y=None):
        self.input_data = X
        return self.model.fit_predict(X)
Base class for all estimators in scikit-learn.
Inheriting from this class provides default implementations of:
- setting and getting parameters used by GridSearchCV and friends;
- textual and HTML representation displayed in terminals and IDEs;
- estimator serialization;
- parameters validation;
- data validation;
- feature names validation.
Read more in the scikit-learn User Guide (rolling_your_own_estimator).
Notes
All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).
Examples
>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
... def __init__(self, *, param=1):
... self.param = param
... def fit(self, X, y=None):
... self.is_fitted_ = True
... return self
... def predict(self, X):
... return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])
    def __init__(self, height, width, neighbors=4, **kwargs):
        self.height = height
        self.width = width
        self.neighbors = neighbors

        self.grid_points = np.indices((height, width)).reshape(2, -1).T
        if neighbors == 4:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
            )
        elif neighbors == 8:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
            )

        self.connectivity = connectivity
        self.kwargs = kwargs
        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)
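A minimal sketch clustering one synthetic feature on a 5x5 grid into three spatially connected clusters:

```python
import numpy as np
from fire2a.agglomerative_clustering import CustomAgglomerativeClustering

height, width = 5, 5
X = np.random.rand(height * width, 1)   # (n_pixels, n_features), row-major pixel order

clus = CustomAgglomerativeClustering(height, width, neighbors=4, n_clusters=3)
labels = clus.fit_predict(X).reshape(height, width)
print(labels)
```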
def pipelie(observations, info_list, height, width, **kwargs):
    """A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix
    Steps are:
    1. Impute missing values
    2. Scale the features
    3. Rescale all features to a common range
    4. Cluster the data using Agglomerative Clustering with connectivity
    5. Reshape the labels back to the original spatial map shape
    6. Return the labels and the pipeline object

    Args:
        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
        info_list (list): A list of dictionaries containing information about each feature
        height (int): The height of the spatial map
        width (int): The width of the spatial map
        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold

    Returns:
        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
        Pipeline: The pipeline object containing all the steps of the pipeline
    """
    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}

    # imputer strategies
    no_data_values = [info["NoDataValue"] for info in info_list]
    no_data_strategies = [info["no_data_strategy"] for info in info_list]
    fill_values = [info["fill_value"] for info in info_list]
    weights = [info["weight"] for info in info_list]
    # scaling_strategies = [info["scaling_strategy"] for info in info_list]

    # scaling strategies
    index_map = {}
    for strategy in ["robust", "standard", "onehot"]:
        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
    # index_map
    # !cat config.toml

    # Create transformers for each type
    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
    # OneHotEncoder._n_features_outs):

    # Combine transformers using ColumnTransformer
    feature_scaler = ColumnTransformer(
        transformers=[
            ("robust", robust_transformer, index_map["robust"]),
            ("standard", standard_transformer, index_map["standard"]),
            ("onehot", onehot_transformer, index_map["onehot"]),
        ]
    )

    # # Create a temporary directory for caching calculations
    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
    # import tempfile
    # import joblib
    # temp_dir = tempfile.mkdtemp()
    # memory = joblib.Memory(location=temp_dir, verbose=0)

    # Create and apply the pipeline
    # part 1 until feature scaling
    pipe1 = Pipeline(
        # n_features_in_ : int
        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
        steps=[
            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
            ("feature_scaling", feature_scaler),
        ],
        # memory=memory,
        verbose=True,
    )
    # map weights to new columns (onehot feature scaler creates one column per category)
    obs1 = pipe1.fit_transform(observations)
    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
    split_names = [name.split("_")[0] for name in cat_names]
    cat_count = np.unique(split_names, return_counts=True)[1]
    onehot_map = {}
    for i, key in enumerate(index_map["onehot"]):
        onehot_map[key] = cat_count[i]
    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
    weight_map = []
    for name, idxs in index_map.items():
        for idx in idxs:
            if name == "onehot":
                weight_map += [weights[idx]] * onehot_map[idx]
                continue
            weight_map += [weights[idx]]
    # part 2 use weight_map and cluster
    pipe2 = Pipeline(
        steps=[
            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
        ],
        # memory=memory,
        verbose=True,
    )

    # apply pipeLIE
    labels = pipe2.fit_predict(obs1)

    # Reshape the labels back to the original spatial map shape
    labels_reshaped = labels.reshape(height, width)
    return labels_reshaped, pipe1, pipe2
A scikit-learn pipeline to achieve Agglomerative Clustering with connectivity on a 2D matrix. Steps are:
- Impute missing values
- Scale the features
- Rescale all features to a common range
- Cluster the data using Agglomerative Clustering with connectivity
- Reshape the labels back to the original spatial map shape
- Return the labels and the pipeline object
Args:
    observations (np.ndarray): The input data to cluster, shaped (n_samples, n_features)
    info_list (list): A list of dictionaries containing information about each feature
    height (int): The height of the spatial map
    width (int): The width of the spatial map
    kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold

Returns:
    np.ndarray: The labels of the clusters, reshaped to the original 2D spatial map shape
    Pipeline: The pipeline object containing all the steps of the pipeline
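A hedged end-to-end sketch with synthetic data: one continuous ("robust") and one categorical ("onehot") feature on a 20x20 grid; any extra keyword argument (here linkage) is forwarded to AgglomerativeClustering. The threshold is scaled for this toy data, not a real calibration.

```python
import numpy as np
from fire2a.agglomerative_clustering import pipelie

height, width = 20, 20
continuous = np.random.rand(height, width) * 100
categorical = np.random.randint(0, 3, size=(height, width))   # e.g. 3 hypothetical fuel classes

observations = np.column_stack([continuous.ravel(), categorical.ravel()])
info_list = [
    {"NoDataValue": None, "no_data_strategy": "mean", "fill_value": 0,
     "weight": 1, "scaling_strategy": "robust"},
    {"NoDataValue": None, "no_data_strategy": "most_frequent", "fill_value": 0,
     "weight": 2, "scaling_strategy": "onehot"},
]

labels, pipe1, pipe2 = pipelie(observations, info_list, height, width,
                               n_clusters=None, distance_threshold=0.5, linkage="complete")
print(labels.shape, len(np.unique(labels)))
```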
def write(
    label_map,
    width,
    height,
    output_raster="",
    output_poly="output.shp",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
    nodata=None,
    feedback=None,
):
    from osgeo import gdal, ogr, osr

    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename

    # setup drivers for raster and polygon output formats
    if output_raster == "":
        raster_driver = "MEM"
    else:
        try:
            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
        except Exception:
            raster_driver = "GTiff"
    try:
        poly_driver = get_vector_driver_from_filename(output_poly)
    except Exception:
        poly_driver = "ESRI Shapefile"

    # create raster output
    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
    src_ds.SetGeoTransform(geotransform)  # != 0 ?
    src_ds.SetProjection(authid)  # != 0 ?
    # src_band = src_ds.GetRasterBand(1)
    # if nodata:
    #     src_band.SetNoDataValue(nodata)
    # src_band.WriteArray(label_map)

    # create polygon output
    drv = ogr.GetDriverByName(poly_driver)
    dst_ds = drv.CreateDataSource(output_poly)
    sp_ref = osr.SpatialReference()
    sp_ref.SetFromUserInput(authid)  # != 0 ?
    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
    dst_lyr.CreateField(ogr.FieldDefn("DN", ogr.OFTInteger64))  # != 0 ?
    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", ogr.OFTInteger64))
    # dst_lyr.CreateField(ogr.FieldDefn("area", ogr.OFTInteger))
    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", ogr.OFTInteger))

    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)

    # FAIL: All together it merges labels into a single polygon
    # src_band = src_ds.GetRasterBand(1)
    # if nodata:
    #     src_band.SetNoDataValue(nodata)
    # src_band.WriteArray(label_map)
    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])

    # B: separate approach
    # for loop for creating each label_map value into a different polygonized feature
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    # itera = iter(np.unique(label_map))
    # cluster_id = next(itera)
    areas = []
    pixels = []
    data = np.zeros_like(label_map)
    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
        # temporarily write band
        src_band = src_ds.GetRasterBand(1)
        src_band.SetNoDataValue(0)
        data[label_map == cluster_id] = 1
        src_band.WriteArray(data)
        # create feature
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
        # unset tmp data
        data[label_map == cluster_id] = 0
        # set polygon feat
        feat = tmp_lyr.GetNextFeature()
        geom = feat.GetGeometryRef()
        featureDefn = dst_lyr.GetLayerDefn()
        feature = ogr.Feature(featureDefn)
        feature.SetGeometry(geom)
        feature.SetField("DN", float(cluster_id))
        areas += [geom.GetArea()]
        pixels += [px_count]
        feature.SetField("pixel_count", float(px_count))
        # feature.SetField("area", int(geom.GetArea()))
        # feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
    # RESTART RASTER
    # src_ds = None
    # src_band = None
    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, gdal.GDT_Int64)
    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
    # src_ds.SetProjection(authid)  # != 0 ?
    src_band = src_ds.GetRasterBand(1)
    if nodata:
        src_band.SetNoDataValue(nodata)
    else:
        # useless paranoia ?
        src_band.SetNoDataValue(-1)
    src_band.WriteArray(label_map)
    # close datasets
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
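write has no rendered docstring; a hedged sketch of calling it directly with a tiny synthetic label map (assumes GDAL/OGR and fire2a.processing_utils are importable; the output filename is hypothetical):

```python
import numpy as np
from fire2a.agglomerative_clustering import write

height, width = 4, 6
label_map = np.repeat(np.arange(3), height * width // 3).reshape(height, width)

write(label_map, width, height,
      output_raster="",              # "" keeps the raster in memory only
      output_poly="clusters.gpkg",   # hypothetical output file, overwritten if present
      authid="EPSG:3857",
      geotransform=(0, 1, 0, 0, 0, 1))
```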
def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
    Args:
        labels_reshaped (np.ndarray): The reshaped labels of the clusters
        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
        info_list (list): A list of dictionaries containing information about each feature
        **kwargs: Additional keyword arguments
            n_clusters (int): The number of clusters
            distance_threshold (float): The linkage distance threshold
            sieve (int): The number of pixels to use as a sieve filter
            block (bool): Block the execution until the plot window is closed
            filename (str): The filename to save the plot
    """
    from matplotlib import pyplot as plt

    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data

    # filter out onehot features
    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
    num_no_onehots = len(info_list) - num_onehots
    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]

    # indices without onehots
    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]

    # filter onehot out of no_data_treated
    no_data_imputed = no_data_imputed[:, nohots_idxs]

    # reordered: robust first, then standard
    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]

    # reorder rob then std
    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]

    names = [info_list[i]["fname"] for i in nohots_idxs]

    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
    fig, axs = plt.subplots(3, 2, figsize=fgs)
    suptitle = ""
    if n_clusters := kwargs.get("n_clusters"):
        suptitle = f"n_clusters: {n_clusters}"
    if distance_threshold := kwargs.get("distance_threshold"):
        suptitle = f"distance_threshold: {distance_threshold}"
    if sieve := kwargs.get("sieve"):
        suptitle += f", sieve: {sieve}"
    if n_clusters or distance_threshold or sieve:
        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
    suptitle += "\n(Not showing categorical data)"
    fig.suptitle(suptitle)

    # plot violin plot
    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
    axs[0, 0].set_title("Violin Plot of NoData Imputed")
    axs[0, 0].yaxis.grid(True)
    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 0].set_ylabel("Observed values")

    # plot boxplot
    axs[0, 1].boxplot(no_data_imputed)
    axs[0, 1].set_title("Box Plot of NoData Imputed")
    axs[0, 1].yaxis.grid(True)
    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 1].set_ylabel("Observed values")

    # plot violin plot
    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
    axs[1, 0].set_title("Violin Plot of Common Rescaled")
    axs[1, 0].yaxis.grid(True)
    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 1].set_ylabel("Adjusted range")

    # plot boxplot
    axs[1, 1].boxplot(pre_aggclu_data)
    axs[1, 1].set_title("Box Plot of Common Rescaled")
    axs[1, 1].yaxis.grid(True)
    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 1].set_ylabel("Adjusted range")

    # cluster history
    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
    axs[2, 0].set_title("Cluster history size (in pixels)")
    axs[2, 0].set_xlabel("Algorithm Step")
    axs[2, 0].set_ylabel("Size (in pixels)")

    # cluster histogram
    axs[2, 1].hist(counts, log=True)
    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
    axs[2, 1].set_ylabel("Number of Clusters")
    axs[2, 1].set_title("Histogram of Cluster Sizes")

    plt.tight_layout()
    if filename := kwargs.get("filename"):
        logger.info(f"Saving plot to {filename}")
        plt.savefig(filename)
    else:
        if block := kwargs.get("block"):
            plt.show(block=block)
        else:
            plt.show()
Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.

Args:
    labels_reshaped (np.ndarray): The reshaped labels of the clusters
    pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
    pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
    info_list (list): A list of dictionaries containing information about each feature
    **kwargs: Additional keyword arguments
        n_clusters (int): The number of clusters
        distance_threshold (float): The linkage distance threshold
        sieve (int): The number of pixels to use as a sieve filter
        block (bool): Block the execution until the plot window is closed
        filename (str): The filename to save the plot
def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
    Args:
        data (np.ndarray): The input data to filter
        threshold (int): The maximum number of pixels in a cluster to keep
        connectedness (int): The number of connected pixels to consider when filtering 4 or 8
        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
    Returns:
        np.ndarray: The filtered data
    """
    logger.info("Applying sieve filter")
    from osgeo import gdal

    height, width = data.shape
    # fprint("antes", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
    num_clusters = len(np.unique(data))
    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, gdal.GDT_Int64)
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(data)
    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
    else:
        sieved = src_band.ReadAsArray()
        src_band = None
        src_ds = None
        num_sieved = len(np.unique(sieved))
        # fprint("despues", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
        fprint(
            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} less",
            level="info",
            feedback=feedback,
            logger=logger,
        )
        fprint(
            "Please try again increasing distance_threshold or reducing n_clusters instead...",
            level="info",
            feedback=feedback,
            logger=logger,
        )
    # from matplotlib import pyplot as plt
    # fig, (ax1, ax2) = plt.subplots(1, 2)
    # ax1.imshow(data)
    # ax1.set_title("before sieve" + str(len(np.unique(data))))
    # ax2.imshow(sieved)
    # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
    # plt.show()
    # data = sieved
    return sieved
Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library: https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve

Args:
- data (np.ndarray): The input data to filter
- threshold (int): Clusters with fewer pixels than this threshold are merged into their largest neighbor
- connectedness (int): The pixel connectivity to use when filtering, 4 or 8
- feedback (QgsTaskFeedback): A feedback object to report progress, for use inside QGIS plugins

Returns:
- np.ndarray: The filtered data
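As a quick illustration of the function on its own, the sketch below sieves a synthetic label map; the array shape, threshold and connectedness are arbitrary placeholders, not recommended values.

```python
# Hedged sketch: sieve a synthetic label map (illustrative values only)
import numpy as np
from fire2a.agglomerative_clustering import sieve_filter

labels = np.random.randint(0, 50, size=(100, 100))  # stand-in for a real cluster label raster
cleaned = sieve_filter(labels, threshold=30, connectedness=8)
print(len(np.unique(labels)), "->", len(np.unique(cleaned)), "clusters")
```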
```python
def arg_parser(argv=None):
    """Parse command line arguments."""
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(
        description="Agglomerative Clustering with Connectivity for raster data",
        formatter_class=ArgumentDefaultsHelpFormatter,
        epilog="More at https://fire2a.github.io/fire2a-lib",
    )
    parser.add_argument(
        "config_file",
        nargs="?",
        type=Path,
        help="For each raster file, configure its preprocess: nodata & scaling methods",
        default="config.toml",
    )

    aggclu = parser.add_mutually_exclusive_group(required=True)
    aggclu.add_argument(
        "-d",
        "--distance_threshold",
        type=float,
        help="Distance threshold (a good starting point when scaling is 10; higher means fewer clusters, 0 could take a long time)",
    )
    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")

    parser.add_argument("-or", "--output_raster", help="Output raster file, warning: overwrites!", default="")
    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning: overwrites!", default="output.gpkg")
    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
    parser.add_argument(
        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
    )
    parser.add_argument(
        "-nw",
        "--no_write",
        action="store_true",
        help="Do not write the output raster nor polygons",
        default=False,
    )
    parser.add_argument(
        "-s",
        "--script",
        action="store_true",
        help="Run in script mode, returning the label_map and the pipeline objects",
        default=False,
    )
    parser.add_argument(
        "--sieve",
        type=int,
        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")

    plot = parser.add_argument_group(
        "Plotting: visually inspect input distributions (NoData-treated observations and rescaled data) with violin plots and boxplots; also check the output clustering size history and histograms"
    )
    plot.add_argument(
        "-p",
        "--plots",
        action="store_true",
        help="Activate the plotting routines",
    )
    plot.add_argument(
        "-b",
        "--block",
        action="store_false",
        default=True,
        help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS",
    )
    plot.add_argument(
        "-f",
        "--filename",
        type=str,
        help="Filename to save the plot. If not provided, matplotlib will raise a window",
    )
    args = parser.parse_args(argv)
    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
    if not Path(args.config_file).is_file():
        parser.error(f"File {args.config_file} not found")
    return args
```
Parse command line arguments.
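`arg_parser` can also be called programmatically with a list of strings instead of reading `sys.argv`. Below is a minimal sketch with placeholder values, assuming a `config.toml` exists in the current directory (otherwise `parser.error` exits).

```python
# Hedged sketch: parse arguments from a list instead of sys.argv (placeholder values)
from fire2a.agglomerative_clustering import arg_parser

args = arg_parser(["-d", "10.0", "--sieve", "25", "-p", "-f", "diagnostics.png"])
print(args.distance_threshold, args.sieve, args.plots, args.filename)
# note: -d/--distance_threshold and -n/--n_clusters are mutually exclusive, and one is required
```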
```python
def main(argv=None):
    """Run the whole pipeline: read config and rasters, cluster, optionally sieve, plot and write outputs.

    Examples:
        args = arg_parser(["-d", "10.0", "-g", "(0, 10, 0, 0, 0, 10)", "config2.toml"])
        args = arg_parser(["-d", "10.0"])
        args = arg_parser(["-d", "10.0", "config2.toml"])
        args = arg_parser(["-n", "10"])
    """
    if argv is sys.argv:
        argv = sys.argv[1:]
    args = arg_parser(argv)

    if args.verbose != 0:
        global logger
        from fire2a import setup_logger

        logger = setup_logger(verbosity=args.verbose)

    logger.info("args %s", args)

    # 2. READ CONFIG
    config = read_toml(args.config_file)
    # logger.debug(config)

    # 2.1 ADD DEFAULTS
    for filename, file_config in config.items():
        if "no_data_strategy" not in file_config:
            config[filename]["no_data_strategy"] = "mean"
        if "scaling_strategy" not in file_config:
            config[filename]["scaling_strategy"] = "robust"
        if "fill_value" not in file_config:
            config[filename]["fill_value"] = 0
        if "weight" not in file_config:
            config[filename]["weight"] = 1
    logger.debug(config)

    # 3. READ DATA
    from fire2a.raster import read_raster

    data_list, info_list = [], []
    for filename, file_config in config.items():
        data, info = read_raster(filename)
        info["fname"] = Path(filename).name
        info["no_data_strategy"] = file_config["no_data_strategy"]
        info["scaling_strategy"] = file_config["scaling_strategy"]
        info["fill_value"] = file_config["fill_value"]
        info["weight"] = file_config["weight"]
        data_list += [data]
        info_list += [info]
        logger.debug("%s", data[:2, :2])
        logger.debug("%s", info)

    # 4. VALIDATE: all rasters must share the same 2d shape
    height, width = check_shapes(data_list)

    # 5. list of maps -> observations matrix
    observations = np.column_stack([data.ravel() for data in data_list])

    # 6. nodata -> feature scaling -> all scaling -> clustering
    labels_reshaped, pipe1, pipe2 = pipelie(
        observations,
        info_list,
        height,
        width,
        n_clusters=args.n_clusters,
        distance_threshold=args.distance_threshold,
    )  # insert more keyword arguments to the clustering algorithm here!

    # SIEVE
    if args.sieve:
        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)

    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")

    # 7. debugging plots
    if args.plots:
        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))

    # 8. WRITE RASTER AND/OR POLYGONS
    if not args.no_write:
        if not write(
            labels_reshaped,
            width,
            height,
            output_raster=args.output_raster,
            output_poly=args.output_poly,
            authid=args.authid,
            geotransform=args.geotransform,
        ):
            logger.error("Error writing output raster")

    # 9. SCRIPT MODE
    if args.script:
        return labels_reshaped, pipe1, pipe2

    return 0
```
args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"]) args = arg_parser(["-d","10.0"]]) args = arg_parser(["-d","10.0", "config2.toml"]) args = arg_parser(["-n","10"])