fire2a.agglomerative_clustering
# Raster clustering

## Usage

### Overview

1. Choose your raster files
2. Configure nodata, scaling strategies and weights in the `config.toml` file
3. Choose a "distance threshold" (or a "number of clusters") for the [Agglomerative](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) clustering algorithm. Recommended:
   - Start with a distance threshold of 10.0; decrease it for more clusters, increase it for fewer
   - After calibrating the distance threshold, [sieve](https://gdal.org/en/latest/programs/gdal_sieve.html) small clusters (merge them into their biggest neighbor) with the `--sieve integer_pixels_size` option (see the example right after this list)
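For instance, a typical calibration session could look like the following sketch (the distance thresholds and the 30-pixel sieve size are arbitrary values you would tune for your own rasters):

```bash
# coarser clustering (fewer clusters)
python -m fire2a.agglomerative_clustering -d 20.0
# finer clustering (more clusters)
python -m fire2a.agglomerative_clustering -d 5.0
# once calibrated, also merge clusters smaller than 30 pixels into their biggest neighbor
python -m fire2a.agglomerative_clustering -d 10.0 --sieve 30
```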
### Execution

```bash
# get command line help
python -m fire2a.agglomerative_clustering -h
python -m fire2a.agglomerative_clustering --help

# activate your qgis dev environment
source ~/pyqgisdev/bin/activate
# execute
(qgis) $ python -m fire2a.agglomerative_clustering -d 10.0

# windows users should use QGIS's python
C:\PROGRA~1\QGIS33~1.3\bin\python-qgis.bat -m fire2a.agglomerative_clustering -d 10.0
```
[More info on how to run with QGIS's python on Windows](https://github.com/fire2a/fire2a-lib/tree/main/qgis-launchers)
### Preparation

#### 1. Choose your raster files
- Any [GDAL compatible](https://gdal.org/en/latest/drivers/raster/index.html) raster will be read
- Place them all in the same directory where the script will be executed
- "Quote them" in `config.toml` if the filenames contain any characters outside [a-zA-Z0-9]

#### 2. Preprocessing configuration
See the `config.toml` file for an example configuration of the preprocessing steps. The file is structured as follows:

```toml
["filename.tif"]
no_data_strategy = "most_frequent"
scaling_strategy = "onehot"
fill_value = 0
weight = 1
```
1. __scaling_strategy__
   - can be "standard", "robust", "onehot"
   - default is "robust"
   - [Standard](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html): (x-mean)/stddev
   - [Robust](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html): same, but dropping the tails of the distribution
   - [OneHot](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html): __for CATEGORICAL DATA__

2. __no_data_strategy__
   - can be "mean", "median", "most_frequent", "constant"
   - default is "mean"
   - categorical data should use "most_frequent" or "constant"
   - "constant" will use the value in __fill_value__ (see below)
   - implemented with [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)

3. __fill_value__
   - used when __no_data_strategy__ is "constant"
   - default is 0
   - see [SimpleImputer](https://scikit-learn.org/stable/modules/generated/sklearn.impute.SimpleImputer.html)

4. __weight__
   - default is 1
   - used to give more importance to some features than others
   - applied after the nodata imputation and scaling steps, before clustering (a combined example follows this list)
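To illustrate how these options combine, a configuration mixing one continuous and one categorical raster could look like the following sketch (the filenames `elevation.tif` and `fuels.tif` are placeholders, not files shipped with the library):

```toml
# continuous data: robust scaling, nodata replaced by the median, double weight
["elevation.tif"]
no_data_strategy = "median"
scaling_strategy = "robust"
weight = 2

# categorical data: one-hot encoding, nodata replaced by a constant class
["fuels.tif"]
no_data_strategy = "constant"
fill_value = 0
scaling_strategy = "onehot"
weight = 1
```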
#### 3. Clustering configuration

1. The __Agglomerative__ clustering algorithm is used. The following parameters are mutually exclusive:
   - `-n` or `--n_clusters`: The number of clusters to form as well as the number of centroids to generate.
   - `-d` or `--distance_threshold`: The linkage distance threshold above which clusters will not be merged. With the default scaling, start at 10.0 and work downward (0.0 computes the full merge tree and can take a long time).
   - More [parameters](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AgglomerativeClustering.html) for the clustering can be passed directly to the `pipelie` function as keyword arguments (see the sketch after this list)

2. A __Sieve filter__ can be applied to remove small clusters, using the [GDAL sieve library](https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve):
   - `--sieve`: merge small clusters (threshold in number of pixels) into the biggest neighbor
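The keyword-argument forwarding might be used like this (a sketch, not a complete program: it assumes `observations`, `info_list`, `height` and `width` were already prepared exactly as `main` does, and `linkage` is just one example of an extra AgglomerativeClustering keyword argument):

```python
from fire2a.agglomerative_clustering import pipelie

# observations: (n_pixels, n_features) array; info_list: per-raster config dicts;
# height, width: raster shape -- all built as in main()
labels_reshaped, pipe1, pipe2 = pipelie(
    observations,
    info_list,
    height,
    width,
    n_clusters=None,        # exactly one of n_clusters / distance_threshold must be set
    distance_threshold=10.0,
    linkage="complete",     # any other AgglomerativeClustering keyword argument
)
```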
#### 4. Post-processing

Outputs can be:
- A raster file with the cluster labels and a polygon file with the cluster polygons
- A polygon file with the cluster polygons, with an attribute holding the number of pixels in each cluster
- A plot of the input data distributions, the rescaled data distributions, and the cluster size history and histogram (crashes QGIS on Windows)

Or use the `--script` option to return the label_map and the pipeline objects for further processing in another python script:

```python
from fire2a.agglomerative_clustering import main
label_map, pipe1, pipe2 = main(["-d", "10.0", "-s"])
```
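Building on that call (a small follow-up sketch; it only assumes `label_map` is the 2D array of cluster labels returned above), the result can be inspected directly with numpy:

```python
import numpy as np

# count how many pixels ended up in each cluster
labels, counts = np.unique(label_map, return_counts=True)
print(f"{len(labels)} clusters; smallest: {counts.min()} px, largest: {counts.max()} px")
```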
Module-level source (imports and setup; each function's source is listed with its documentation below):

```python
#!/usr/bin/env python3
# fmt: off
# (module docstring omitted here; it is rendered at the top of this page)
# fmt: on
# from IPython.terminal.embed import InteractiveShellEmbed
# InteractiveShellEmbed()()
import logging
import sys
from pathlib import Path

import numpy as np
from osgeo import gdal, ogr, osr
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cluster import AgglomerativeClustering
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.neighbors import radius_neighbors_graph
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, RobustScaler, StandardScaler

from fire2a.utils import fprint, read_toml

try:
    GDT = gdal.GDT_Int64
except:
    GDT = gdal.GDT_Int32
try:
    OFT = ogr.OFTInteger64
except:
    OFT = ogr.OFTInteger

logger = logging.getLogger(__name__)
```
```python
def check_shapes(data_list):
    """Check if all data arrays have the same shape and are 2D.
    Returns the shape of the data arrays if they are all equal.
    """
    from functools import reduce

    def equal_or_error(x, y):
        """Check if x and y are equal, returns x if equal else raises a ValueError."""
        if x == y:
            return x
        else:
            raise ValueError("All data arrays must have the same shape")

    shape = reduce(equal_or_error, (data.shape for data in data_list))
    if len(shape) != 2:
        raise ValueError("All data arrays must be 2D")
    height, width = shape
    return height, width
```
Check if all data arrays have the same shape and are 2D. Returns the shape of the data arrays if they are all equal.
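A quick usage sketch (synthetic arrays, just to show the contract; mismatched shapes or non-2D arrays raise ValueError):

```python
import numpy as np
from fire2a.agglomerative_clustering import check_shapes

data_list = [np.zeros((3, 4)), np.ones((3, 4))]
height, width = check_shapes(data_list)
print(height, width)  # 3 4
```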
```python
def get_map_neighbors(height, width, num_neighbors=8):
    """Get the neighbors of each cell in a 2D grid.
    n_jobs=-1 uses all available cores.
    """

    grid_points = np.indices((height, width)).reshape(2, -1).T

    nb4 = radius_neighbors_graph(grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1)
    nb8 = radius_neighbors_graph(grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1)

    # assert nb4.shape[0] == width * height
    # assert nb8.shape[1] == width * height
    # for n in range(width * height):
    #     _, neighbors = np.nonzero(nb4[n])
    #     assert 2 <= len(neighbors) <= 4, f"{n=} {neighbors=}"
    #     assert 3 <= len(neighbors) <= 8, f"{n=} {neighbors=}"
    return nb4, nb8
```
Get the neighbors of each cell in a 2D grid. n_jobs=-1 uses all available cores.
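For example (a small sketch; the function returns scipy sparse adjacency matrices of shape (height*width, height*width)):

```python
from fire2a.agglomerative_clustering import get_map_neighbors

# 4- and 8-connected adjacency for a 3x3 grid
nb4, nb8 = get_map_neighbors(3, 3)
print(nb4.shape)          # (9, 9)
print(nb4.nnz, nb8.nnz)   # number of 4- and 8-neighbor links
```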
```python
class NoDataImputer(BaseEstimator, TransformerMixin):
    """A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column"""

    def __init__(self, no_data_values, strategies, constants):
        self.no_data_values = no_data_values
        self.strategies = strategies
        self.constants = constants
        self.imputers = []
        for no_data_value, strategy, constant in zip(no_data_values, strategies, constants):
            if no_data_value:
                self.imputers += [SimpleImputer(strategy=strategy, missing_values=no_data_value, fill_value=constant)]
            else:
                self.imputers += [SimpleImputer(strategy=strategy, fill_value=constant)]

    def fit(self, X, y=None):
        for i, imputer in enumerate(self.imputers):
            imputer.fit(X[:, [i]], y)
        return self

    def transform(self, X):
        for i, imputer in enumerate(self.imputers):
            X[:, [i]] = imputer.transform(X[:, [i]])
        self.output_data = X
        return X
```
A custom Imputer that treats a specified nodata_value as np.nan and supports different strategies per column
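A minimal sketch of standalone use (the nodata values, strategies and constants below are made up; in the pipeline they come from each raster's metadata and `config.toml`):

```python
import numpy as np
from fire2a.agglomerative_clustering import NoDataImputer

X = np.array([[1.0, -9999.0], [2.0, 5.0], [-1.0, 7.0]])
# column 0: nodata is -1, replaced by the column mean; column 1: nodata is -9999, replaced by the constant 0
imputer = NoDataImputer(no_data_values=[-1, -9999], strategies=["mean", "constant"], constants=[0, 0])
print(imputer.fit(X).transform(X))
# nodata cells become 1.5 (column 0 mean) and 0.0 (the constant)
```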
```python
class RescaleAllToCommonRange(BaseEstimator, TransformerMixin):
    """A custom transformer that rescales all features to a common range [0, 1]"""

    def __init__(self, weight_map):
        self.weight_map = weight_map

    def fit(self, X, y=None):
        # Determine the combined range of all scaled features
        self.min_val = [x.min() for x in X.T]
        self.max_val = [x.max() for x in X.T]
        return self

    def transform(self, X):
        # Rescale all features to match the common range
        for i, (x, mi, ma) in enumerate(zip(X.T, self.min_val, self.max_val)):
            if ma - mi == 0:
                X.T[i] = x * self.weight_map[i]
            else:
                X.T[i] = (x - mi) / (ma - mi) * self.weight_map[i]
        return X
```
A custom transformer that rescales all features to a common range [0, 1]
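A small sketch (synthetic data; the weight list length must match the number of columns):

```python
import numpy as np
from fire2a.agglomerative_clustering import RescaleAllToCommonRange

X = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
# rescale both columns to [0, 1], then give the second column twice the weight
rescaler = RescaleAllToCommonRange(weight_map=[1, 2])
print(rescaler.fit(X).transform(X))
# column 0 becomes [0, 0.5, 1]; column 1 becomes [0, 1, 2]
```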
```python
class CustomAgglomerativeClustering(BaseEstimator, TransformerMixin):
    def __init__(self, height, width, neighbors=4, **kwargs):
        self.height = height
        self.width = width
        self.neighbors = neighbors

        self.grid_points = np.indices((height, width)).reshape(2, -1).T
        if neighbors == 4:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=1, metric="manhattan", include_self=False, n_jobs=-1
            )
        elif neighbors == 8:
            connectivity = radius_neighbors_graph(
                self.grid_points, radius=2 ** (1 / 2), metric="euclidean", include_self=False, n_jobs=-1
            )

        self.connectivity = connectivity
        self.kwargs = kwargs
        self.model = AgglomerativeClustering(connectivity=self.connectivity, **self.kwargs)

    def fit(self, X, y=None):
        logger.debug("not sure why, but this method is never called alas needed")
        self.model.fit(X)
        return self

    def fit_predict(self, X, y=None):
        self.input_data = X
        return self.model.fit_predict(X)
```
(Inherits from `sklearn.base.BaseEstimator` and `sklearn.base.TransformerMixin`.)
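A sketch of standalone use on a tiny synthetic map (in the real pipeline this step receives the rescaled, weighted observations; here `n_clusters=3` is arbitrary):

```python
import numpy as np
from fire2a.agglomerative_clustering import CustomAgglomerativeClustering

height, width = 5, 5
# one feature per pixel, flattened to (n_pixels, n_features)
X = np.random.default_rng(0).random((height * width, 1))
clu = CustomAgglomerativeClustering(height, width, neighbors=4, n_clusters=3)
labels = clu.fit_predict(X).reshape(height, width)
print(labels)
```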
```python
def pipelie(observations, info_list, height, width, **kwargs):
    """A scipy pipeline to achieve Agglomerative Clustering with connectivity on 2d matrix
    Steps are:
    1. Impute missing values
    2. Scale the features
    3. Rescale all features to a common range
    4. Cluster the data using Agglomerative Clustering with connectivity
    5. Reshape the labels back to the original spatial map shape
    6. Return the labels and the pipeline object

    Args:
        observations (np.ndarray): The input data to cluster (n_samples, n_features) shaped
        info_list (list): A list of dictionaries containing information about each feature
        height (int): The height of the spatial map
        width (int): The width of the spatial map
        kwargs: Additional keyword arguments for AgglomerativeClustering, at least one of n_clusters or distance_threshold

    Returns:
        np.ndarray: The labels of the clusters, reshaped to the original 2d spatial map shape
        Pipeline: The pipeline object containing all the steps of the pipeline
    """
    # kwargs = {"n_clusters": args.n_clusters, "distance_threshold": args.distance_threshold}

    # imputer strategies
    no_data_values = [info["NoDataValue"] for info in info_list]
    no_data_strategies = [info["no_data_strategy"] for info in info_list]
    fill_values = [info["fill_value"] for info in info_list]
    weights = [info["weight"] for info in info_list]
    # scaling_strategies = [info["scaling_strategy"] for info in info_list]

    # scaling strategies
    index_map = {}
    for strategy in ["robust", "standard", "onehot"]:
        index_map[strategy] = [i for i, info in enumerate(info_list) if info["scaling_strategy"] == strategy]
    # index_map
    # !cat config.toml

    # Create transformers for each type
    robust_transformer = Pipeline(steps=[("robust_step", RobustScaler())])
    standard_transformer = Pipeline(steps=[("standard_step", StandardScaler())])
    onehot_transformer = Pipeline(steps=[("onehot_step", OneHotEncoder(sparse_output=False))])
    # OneHotEncoder._n_features_outs):

    # Combine transformers using ColumnTransformer
    feature_scaler = ColumnTransformer(
        transformers=[
            ("robust", robust_transformer, index_map["robust"]),
            ("standard", standard_transformer, index_map["standard"]),
            ("onehot", onehot_transformer, index_map["onehot"]),
        ]
    )

    # # Create a temporary directory for caching calculations
    # # FOR ACCESING STEPS LATER ON VERY LARGE DATASETS
    # import tempfile
    # import joblib
    # temp_dir = tempfile.mkdtemp()
    # memory = joblib.Memory(location=temp_dir, verbose=0)

    # Create and apply the pipeline
    # part 1 until feature scaling
    pipe1 = Pipeline(
        # n_features_in_ : int
        # feature_names_in_ : ndarray of shape (`n_features_in_`,)
        steps=[
            ("no_data_imputer", NoDataImputer(no_data_values, no_data_strategies, fill_values)),
            ("feature_scaling", feature_scaler),
        ],
        # memory=memory,
        verbose=True,
    )
    # map weights to new columns (onehot feature scaler creates one column per category)
    obs1 = pipe1.fit_transform(observations)
    cat_names = pipe1.named_steps["feature_scaling"]["onehot"].get_feature_names_out()
    split_names = [name.split("_")[0] for name in cat_names]
    cat_count = np.unique(split_names, return_counts=True)[1]
    onehot_map = {}
    for i, key in enumerate(index_map["onehot"]):
        onehot_map[key] = cat_count[i]
    # onehot_map = {key: cat_count[i] for i, key in enumerate(index_map["onehot"])}
    weight_map = []
    for name, idxs in index_map.items():
        for idx in idxs:
            if name == "onehot":
                weight_map += [weights[idx]] * onehot_map[idx]
                continue
            weight_map += [weights[idx]]
    # part 2 use weight_map and cluster
    pipe2 = Pipeline(
        steps=[
            ("common_rescaling", RescaleAllToCommonRange(weight_map)),
            ("agglomerative_clustering", CustomAgglomerativeClustering(height, width, neighbors=4, **kwargs)),
        ],
        # memory=memory,
        verbose=True,
    )

    # apply pipeLIE
    labels = pipe2.fit_predict(obs1)

    # Reshape the labels back to the original spatial map shape
    labels_reshaped = labels.reshape(height, width)
    return labels_reshaped, pipe1, pipe2
```
A scikit-learn pipeline to achieve Agglomerative Clustering with connectivity on a 2D matrix.

Steps are:
1. Impute missing values
2. Scale the features
3. Rescale all features to a common range
4. Cluster the data using Agglomerative Clustering with connectivity
5. Reshape the labels back to the original spatial map shape
6. Return the labels and the pipeline objects

Args:
    observations (np.ndarray): The input data to cluster, (n_samples, n_features) shaped
    info_list (list): A list of dictionaries containing information about each feature
    height (int): The height of the spatial map
    width (int): The width of the spatial map
    kwargs: Additional keyword arguments for AgglomerativeClustering; at least one of n_clusters or distance_threshold

Returns:
    np.ndarray: The labels of the clusters, reshaped to the original 2D spatial map shape
    Pipeline: The pipeline objects containing all the steps of the pipeline
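After a run, the intermediate arrays kept by the custom steps can be inspected, as `plot` does (a sketch; `pipe1` and `pipe2` are the pipeline objects returned by `pipelie` or by `main` in script mode):

```python
# data after nodata imputation (before per-feature scaling)
imputed = pipe1.named_steps["no_data_imputer"].output_data
# data fed into the clustering step (after common rescaling and weighting)
rescaled = pipe2.named_steps["agglomerative_clustering"].input_data
print(imputed.shape, rescaled.shape)
```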
```python
def write(
    label_map,
    width,
    height,
    output_raster="",
    output_poly="output.shp",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
    nodata=None,
    feedback=None,
):

    from fire2a.processing_utils import get_output_raster_format, get_vector_driver_from_filename

    # setup drivers for raster and polygon output formats
    if output_raster == "":
        raster_driver = "MEM"
    else:
        try:
            raster_driver = get_output_raster_format(output_raster, feedback=feedback)
        except Exception:
            raster_driver = "GTiff"
    try:
        poly_driver = get_vector_driver_from_filename(output_poly)
    except Exception:
        poly_driver = "ESRI Shapefile"

    # create raster output
    src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
    src_ds.SetGeoTransform(geotransform)  # != 0 ?
    src_ds.SetProjection(authid)  # != 0 ?
    # src_band = src_ds.GetRasterBand(1)
    # if nodata:
    #     src_band.SetNoDataValue(nodata)
    # src_band.WriteArray(label_map)

    # create polygon output
    drv = ogr.GetDriverByName(poly_driver)
    dst_ds = drv.CreateDataSource(output_poly)
    sp_ref = osr.SpatialReference()
    sp_ref.SetFromUserInput(authid)  # != 0 ?
    dst_lyr = dst_ds.CreateLayer("clusters", srs=sp_ref, geom_type=ogr.wkbPolygon)
    dst_lyr.CreateField(ogr.FieldDefn("DN", OFT))  # != 0 ?
    dst_lyr.CreateField(ogr.FieldDefn("pixel_count", OFT))
    # dst_lyr.CreateField(ogr.FieldDefn("area", OFT))
    # dst_lyr.CreateField(ogr.FieldDefn("perimeter", OFT))

    # 0 != gdal.Polygonize( srcband, maskband, dst_layer, dst_field, options, callback = gdal.TermProgress)

    # FAIL: All together it merges labels into a single polygon
    # src_band = src_ds.GetRasterBand(1)
    # if nodata:
    #     src_band.SetNoDataValue(nodata)
    # src_band.WriteArray(label_map)
    # gdal.Polygonize(src_band, None, dst_lyr, 0, callback=gdal.TermProgress)  # , ["8CONNECTED=8"])

    # B: separate approach
    # for loop for creating each label_map value into a different polygonized feature
    mem_drv = ogr.GetDriverByName("Memory")
    tmp_ds = mem_drv.CreateDataSource("tmp_ds")
    # itera = iter(np.unique(label_map))
    # cluster_id = next(itera)
    areas = []
    pixels = []
    data = np.zeros_like(label_map)
    for cluster_id, px_count in zip(*np.unique(label_map, return_counts=True)):
        # temporarily write band
        src_band = src_ds.GetRasterBand(1)
        src_band.SetNoDataValue(0)
        data[label_map == cluster_id] = 1
        src_band.WriteArray(data)
        # create feature
        tmp_lyr = tmp_ds.CreateLayer("", srs=sp_ref)
        gdal.Polygonize(src_band, src_band.GetMaskBand(), tmp_lyr, -1)
        # unset tmp data
        data[label_map == cluster_id] = 0
        # set polygon feat
        feat = tmp_lyr.GetNextFeature()
        geom = feat.GetGeometryRef()
        featureDefn = dst_lyr.GetLayerDefn()
        feature = ogr.Feature(featureDefn)
        feature.SetGeometry(geom)
        feature.SetField("DN", float(cluster_id))
        areas += [geom.GetArea()]
        pixels += [px_count]
        feature.SetField("pixel_count", float(px_count))
        # feature.SetField("area", int(geom.GetArea()))
        # feature.SetField("perimeter", int(geom.Boundary().Length()))
        dst_lyr.CreateFeature(feature)

    fprint(f"Polygon Areas: {min(areas)=} {max(areas)=}", level="info", feedback=feedback, logger=logger)
    fprint(f"Cluster PixelCounts: {min(pixels)=} {max(pixels)=}", level="info", feedback=feedback, logger=logger)
    # RESTART RASTER
    # src_ds = None
    # src_band = None
    # src_ds = gdal.GetDriverByName(raster_driver).Create(output_raster, width, height, 1, GDT)
    # src_ds.SetGeoTransform(geotransform)  # != 0 ?
    # src_ds.SetProjection(authid)  # != 0 ?
    src_band = src_ds.GetRasterBand(1)
    if nodata:
        src_band.SetNoDataValue(nodata)
    else:
        # useless paranoia ?
        src_band.SetNoDataValue(-1)
    src_band.WriteArray(label_map)
    # close datasets
    src_ds.FlushCache()
    src_ds = None
    dst_ds.FlushCache()
    dst_ds = None
    return True
```
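A sketch of a direct call, mirroring how `main` invokes it (the output filenames are placeholders; `labels_reshaped`, `width` and `height` come from the clustering step, and authid/geotransform describe the output georeference):

```python
from fire2a.agglomerative_clustering import write

ok = write(
    labels_reshaped,
    width,
    height,
    output_raster="labels.tif",
    output_poly="clusters.gpkg",
    authid="EPSG:3857",
    geotransform=(0, 1, 0, 0, 0, 1),
)
```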
```python
def plot(labels_reshaped, pipe1, pipe2, info_list, **kwargs):
    """Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.
    Args:
        labels_reshaped (np.ndarray): The reshaped labels of the clusters
        pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
        pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
        info_list (list): A list of dictionaries containing information about each feature
        **kwargs: Additional keyword arguments
            n_clusters (int): The number of clusters
            distance_threshold (float): The linkage distance threshold
            sieve (int): The number of pixels to use as a sieve filter
            block (bool): Block the execution until the plot window is closed
            filename (str): The filename to save the plot
    """
    from matplotlib import pyplot as plt

    no_data_imputed = pipe1.named_steps["no_data_imputer"].output_data
    pre_aggclu_data = pipe2.named_steps["agglomerative_clustering"].input_data

    # filter out onehot columns
    num_onehots = sum([1 for i in info_list if i["scaling_strategy"] == "onehot"])
    num_no_onehots = len(info_list) - num_onehots
    pre_aggclu_data = pre_aggclu_data[:, :num_no_onehots]

    # indices without onehots
    nohots_idxs = [i for i, info in enumerate(info_list) if info["scaling_strategy"] != "onehot"]

    # filter onehots out of the nodata-treated data
    no_data_imputed = no_data_imputed[:, nohots_idxs]

    # reordered as robust first, then standard
    rob_std_idxs = [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "robust"]
    rob_std_idxs += [i for i, j in enumerate(nohots_idxs) if info_list[j]["scaling_strategy"] == "standard"]

    # reorder rob then std
    pre_aggclu_data = pre_aggclu_data[:, rob_std_idxs]

    names = [info_list[i]["fname"] for i in nohots_idxs]

    fgs = np.array(plt.rcParams["figure.figsize"]) * 5
    fig, axs = plt.subplots(3, 2, figsize=fgs)
    suptitle = ""
    if n_clusters := kwargs.get("n_clusters"):
        suptitle = f"n_clusters: {n_clusters}"
    if distance_threshold := kwargs.get("distance_threshold"):
        suptitle = f"distance_threshold: {distance_threshold}"
    if sieve := kwargs.get("sieve"):
        suptitle += f", sieve: {sieve}"
    if n_clusters or distance_threshold or sieve:
        suptitle += f", resulting clusters: {len(np.unique(labels_reshaped))}"
    suptitle += "\n(Not showing categorical data)"
    fig.suptitle(suptitle)

    # plot violin plot
    axs[0, 0].violinplot(no_data_imputed, showmeans=False, showmedians=True, showextrema=True)
    axs[0, 0].set_title("Violin Plot of NoData Imputed")
    axs[0, 0].yaxis.grid(True)
    axs[0, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 0].set_ylabel("Observed values")

    # plot boxplot
    axs[0, 1].boxplot(no_data_imputed)
    axs[0, 1].set_title("Box Plot of NoData Imputed")
    axs[0, 1].yaxis.grid(True)
    axs[0, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[0, 1].set_ylabel("Observed values")

    # plot violin plot
    axs[1, 0].violinplot(pre_aggclu_data, showmeans=False, showmedians=True, showextrema=True)
    axs[1, 0].set_title("Violin Plot of Common Rescaled")
    axs[1, 0].yaxis.grid(True)
    axs[1, 0].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[1, 0].set_ylabel("Adjusted range")

    # plot boxplot
    axs[1, 1].boxplot(pre_aggclu_data)
    axs[1, 1].set_title("Box Plot of Common Rescaled")
    axs[1, 1].yaxis.grid(True)
    axs[1, 1].set_xticks([y + 1 for y in range(num_no_onehots)], labels=names)
    axs[1, 1].set_ylabel("Adjusted range")

    # cluster history
    unique_labels, counts = np.unique(labels_reshaped, return_counts=True)
    axs[2, 0].plot(unique_labels, counts, marker="o", color="blue")
    axs[2, 0].set_title("Cluster history size (in pixels)")
    axs[2, 0].set_xlabel("Algorithm Step")
    axs[2, 0].set_ylabel("Size (in pixels)")

    # cluster histogram
    axs[2, 1].hist(counts, log=True)
    axs[2, 1].set_xlabel("Cluster Size (in pixels)")
    axs[2, 1].set_ylabel("Number of Clusters")
    axs[2, 1].set_title("Histogram of Cluster Sizes")

    plt.tight_layout()
    if filename := kwargs.get("filename"):
        logger.info(f"Saving plot to {filename}")
        plt.savefig(filename)
    else:
        if block := kwargs.get("block"):
            plt.show(block=block)
        else:
            plt.show()
```
Plot the observed values of the input data, the rescaled data, and the cluster size history and histogram.

Args:
    labels_reshaped (np.ndarray): The reshaped labels of the clusters
    pipe1 (Pipeline): The first pipeline object containing imputer and feature scaling steps
    pipe2 (Pipeline): The second pipeline object containing the rescaling and clustering steps
    info_list (list): A list of dictionaries containing information about each feature
    **kwargs: Additional keyword arguments
        n_clusters (int): The number of clusters
        distance_threshold (float): The linkage distance threshold
        sieve (int): The number of pixels to use as a sieve filter
        block (bool): Block the execution until the plot window is closed
        filename (str): The filename to save the plot
```python
def sieve_filter(data, threshold=2, connectedness=4, feedback=None):
    """Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library. https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve
    Args:
        data (np.ndarray): The input data to filter
        threshold (int): The maximum number of pixels in a cluster to keep
        connectedness (int): The number of connected pixels to consider when filtering 4 or 8
        feedback (QgsTaskFeedback): A feedback object to report progress to use inside QGIS plugins
    Returns:
        np.ndarray: The filtered data
    """
    logger.info("Applying sieve filter")

    height, width = data.shape
    # fprint("before", np.sort(np.unique(data, return_counts=True)), len(np.unique(data)), level="info", feedback=feedback, logger=logger)
    num_clusters = len(np.unique(data))
    src_ds = gdal.GetDriverByName("MEM").Create("sieve", width, height, 1, GDT)
    src_band = src_ds.GetRasterBand(1)
    src_band.WriteArray(data)
    if 0 != gdal.SieveFilter(src_band, None, src_band, threshold, connectedness):
        fprint("Error applying sieve filter", level="error", feedback=feedback, logger=logger)
    else:
        sieved = src_band.ReadAsArray()
        src_band = None
        src_ds = None
        num_sieved = len(np.unique(sieved))
        # fprint("after", np.sort(np.unique(sieved, return_counts=True)), len(np.unique(sieved)), level="info", feedback=feedback, logger=logger)
        fprint(
            f"Reduced from {num_clusters} to {num_sieved} clusters, {num_clusters-num_sieved} less",
            level="info",
            feedback=feedback,
            logger=logger,
        )
        fprint(
            "Please try again increasing distance_threshold or reducing n_clusters instead...",
            level="info",
            feedback=feedback,
            logger=logger,
        )
    # from matplotlib import pyplot as plt
    # fig, (ax1, ax2) = plt.subplots(1, 2)
    # ax1.imshow(data)
    # ax1.set_title("before sieve" + str(len(np.unique(data))))
    # ax2.imshow(sieved)
    # ax2.set_title("after sieve" + str(len(np.unique(sieved))))
    # plt.show()
    # data = sieved
    return sieved
```
Apply a sieve filter to the data to remove small clusters. The sieve filter is applied using the GDAL library: https://gdal.org/en/latest/programs/gdal_sieve.html#gdal-sieve

Args:
    data (np.ndarray): The input data to filter
    threshold (int): Clusters with fewer pixels than this threshold are merged into their biggest neighbor
    connectedness (int): The pixel connectivity to consider when filtering, 4 or 8
    feedback (QgsTaskFeedback): A feedback object to report progress, for use inside QGIS plugins

Returns:
    np.ndarray: The filtered data
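For example (a sketch; `labels_reshaped` is a 2D integer label map such as the one produced by `pipelie`, and 30 pixels is an arbitrary threshold):

```python
from fire2a.agglomerative_clustering import sieve_filter

# merge clusters smaller than 30 pixels into their neighbors (4-connected)
sieved = sieve_filter(labels_reshaped, threshold=30, connectedness=4)
```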
```python
def arg_parser(argv=None):
    """Parse command line arguments."""
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(
        description="Agglomerative Clustering with Connectivity for raster data",
        formatter_class=ArgumentDefaultsHelpFormatter,
        epilog="More at https://fire2a.github.io/fire2a-lib",
    )
    parser.add_argument(
        "config_file",
        nargs="?",
        type=Path,
        help="For each raster file, configure its preprocess: nodata & scaling methods",
        default="config.toml",
    )

    aggclu = parser.add_mutually_exclusive_group(required=True)
    aggclu.add_argument(
        "-d",
        "--distance_threshold",
        type=float,
        help="Distance threshold (a good starting point when scaling is 10, higher means less clusters, 0 could take a long time)",
    )
    aggclu.add_argument("-n", "--n_clusters", type=int, help="Number of clusters")

    parser.add_argument("-or", "--output_raster", help="Output raster file, warning overwrites!", default="")
    parser.add_argument("-op", "--output_poly", help="Output polygons file, warning overwrites!", default="output.gpkg")
    parser.add_argument("-a", "--authid", type=str, help="Output raster authid", default="EPSG:3857")
    parser.add_argument(
        "-g", "--geotransform", type=str, help="Output raster geotransform", default="(0, 1, 0, 0, 0, 1)"
    )
    parser.add_argument(
        "-nw",
        "--no_write",
        action="store_true",
        help="Do not write outputs raster nor polygons",
        default=False,
    )
    parser.add_argument(
        "-s",
        "--script",
        action="store_true",
        help="Run in script mode, returning the label_map and the pipeline object",
        default=False,
    )
    parser.add_argument(
        "--sieve",
        type=int,
        help="Use GDAL sieve filter to merge small clusters (number of pixels) into the biggest neighbor",
    )
    parser.add_argument("--verbose", "-v", action="count", default=0, help="WARNING:1, INFO:2, DEBUG:3")

    plot = parser.add_argument_group(
        "Plotting, Visually inspect input distributions: NoData treated observations, rescaled data, with violin plots and boxplots. Also check output clustering size history and histograms"
    )
    plot.add_argument(
        "-p",
        "--plots",
        action="store_true",
        help="Activate the plotting routines",
    )
    plot.add_argument(
        "-b",
        "--block",
        action="store_false",
        default=True,
        help="Block the execution until the plot window is closed. Use False for interactive ipykernels or QGIS",
    )
    plot.add_argument(
        "-f",
        "--filename",
        type=str,
        help="Filename to save the plot. If not provided, matplotlib will raise a window",
    )
    args = parser.parse_args(argv)
    args.geotransform = tuple(map(float, args.geotransform[1:-1].split(",")))
    if Path(args.config_file).is_file() is False:
        parser.error(f"File {args.config_file} not found")
    return args
```
Parse command line arguments.
```python
def main(argv=None):
    """

    args = arg_parser(["-d","10.0", "-g","(0, 10, 0, 0, 0, 10)", "config2.toml"])
    args = arg_parser(["-d","10.0"])
    args = arg_parser(["-d","10.0", "config2.toml"])
    args = arg_parser(["-n","10"])
    """
    if argv is sys.argv:
        argv = sys.argv[1:]
    args = arg_parser(argv)

    if args.verbose != 0:
        global logger
        from fire2a import setup_logger

        logger = setup_logger(verbosity=args.verbose)

    logger.info("args %s", args)

    # 2. READ CONFIG
    config = read_toml(args.config_file)
    # logger.debug(config)

    # 2.1 ADD DEFAULTS
    for filename, file_config in config.items():
        if "no_data_strategy" not in file_config:
            config[filename]["no_data_strategy"] = "mean"
        if "scaling_strategy" not in file_config:
            config[filename]["scaling_strategy"] = "robust"
        if "fill_value" not in file_config:
            config[filename]["fill_value"] = 0
        if "weight" not in file_config:
            config[filename]["weight"] = 1
    logger.debug(config)

    # 3. READ DATA
    from fire2a.raster import read_raster

    data_list, info_list = [], []
    for filename, file_config in config.items():
        data, info = read_raster(filename)
        info["fname"] = Path(filename).name
        info["no_data_strategy"] = file_config["no_data_strategy"]
        info["scaling_strategy"] = file_config["scaling_strategy"]
        info["fill_value"] = file_config["fill_value"]
        info["weight"] = file_config["weight"]
        data_list += [data]
        info_list += [info]
        logger.debug("%s", data[:2, :2])
        logger.debug("%s", info)

    # 4. VALIDATE: all 2D, same shape
    height, width = check_shapes(data_list)

    # 5. list[maps] -> OBSERVATIONS
    observations = np.column_stack([data.ravel() for data in data_list])

    # 6. nodata -> feature scaling -> all scaling -> clustering
    labels_reshaped, pipe1, pipe2 = pipelie(
        observations,
        info_list,
        height,
        width,
        n_clusters=args.n_clusters,
        distance_threshold=args.distance_threshold,
    )  # insert more keyworded arguments to the clustering algorithm here!

    # SIEVE
    if args.sieve:
        logger.info(f"Number of clusters before sieving: {len(np.unique(labels_reshaped))}")
        labels_reshaped = sieve_filter(labels_reshaped, args.sieve)

    logger.info(f"Final number of clusters: {len(np.unique(labels_reshaped))}")

    # 7. debugging plots
    if args.plots:
        plot(labels_reshaped, pipe1, pipe2, info_list, **vars(args))

    # 8. WRITE RASTER
    if not args.no_write:
        if not write(
            labels_reshaped,
            width,
            height,
            output_raster=args.output_raster,
            output_poly=args.output_poly,
            authid=args.authid,
            geotransform=args.geotransform,
        ):
            logger.error("Error writing output raster")

    # 9. SCRIPT MODE
    if args.script:
        return labels_reshaped, pipe1, pipe2

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
```
Examples:

    args = arg_parser(["-d", "10.0", "-g", "(0, 10, 0, 0, 0, 10)", "config2.toml"])
    args = arg_parser(["-d", "10.0"])
    args = arg_parser(["-d", "10.0", "config2.toml"])
    args = arg_parser(["-n", "10"])