-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathlaunch_forest.py
More file actions
129 lines (100 loc) · 4.27 KB
/
launch_forest.py
File metadata and controls
129 lines (100 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
This script is used to create an HDF5 file from an image, train a random
forest classifier and then classify the image.
"""
import configparser
import logging
import os
import shutil
import sys
import tarfile
import zipfile
from SSWM.utils import bandnames
from SSWM.forest import forest, postprocess
# Chunk size handed to ``RF.predict_chunked`` when classifying the image.
# NOTE(review): the unit (rows, pixels, ...) is defined by predict_chunked -- confirm.
CHUNK_SIZE = 1_000
def untar_VRT(cur_file):
    """Extract an image archive next to itself and return the expected VRT path.

    An archive ``<dir>/<name>.tar`` or ``<dir>/<name>.zip`` is extracted into
    the directory ``<dir>/<name>``.

    *Parameters*
    cur_file : str
        Path to a ``.tar`` or ``.zip`` archive.

    *Returns*
    tuple
        (1) Path to extracted VRT file, ``<dir>/<name>/<name>.vrt``
            (the path is built from the archive name; its existence is not checked)
        (2) Path to working directory containing VRT and image files

    *Raises*
    ValueError
        If the archive extension is neither ``.tar`` nor ``.zip``.
    """
    wkdir, ext = os.path.splitext(cur_file)
    logging.info(f"using {cur_file}")
    if ext == '.tar':
        archive = tarfile.open(cur_file, 'r')
    elif ext == '.zip':
        archive = zipfile.ZipFile(cur_file)
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"Unsupported archive type {ext!r} for {cur_file}; expected '.tar' or '.zip'")
    # TarFile and ZipFile are both context managers: close the handle once extracted.
    with archive:
        archive.extractall(path=wkdir)
    scene_name = os.path.splitext(os.path.basename(cur_file))[0]
    vrt = os.path.join(wkdir, scene_name + ".vrt")
    return vrt, wkdir
def clean_up(extracted_directory):
    """Delete an extracted working directory and the archive it came from.

    The archive is expected to sit next to the directory, with the same name
    plus a ``.tar`` or ``.zip`` extension (the two formats ``untar_VRT``
    accepts). Both the directory and the archive are deleted permanently;
    nothing is moved to a backup location.

    *Parameters*
    extracted_directory : str
        Directory produced by extracting ``extracted_directory + '.tar'`` or
        ``extracted_directory + '.zip'``.

    *Raises*
    FileNotFoundError
        If ``extracted_directory`` does not exist, or if neither a ``.tar``
        nor a ``.zip`` archive is found next to it.
    """
    shutil.rmtree(extracted_directory)  # folder
    removed_any = False
    # Previously only '<dir>.tar' was removed, so zip inputs crashed here.
    for ext in ('.tar', '.zip'):
        archive = extracted_directory + ext
        if os.path.exists(archive):
            os.remove(archive)
            removed_any = True
    if not removed_any:
        raise FileNotFoundError(
            f"No '.tar' or '.zip' archive found for {extracted_directory}")
def failure(output_h5, exdir, cur_file, images_output_dir, msg):
    """Abort processing of one image and leave a ``.failed`` flag file.

    *Parameters*
    output_h5 : str or None
        Intermediate HDF5 file to delete; skipped if ``None`` or if the file
        does not exist.
    exdir : str
        Extracted working directory, passed to ``clean_up``.
    cur_file : str
        Path (or name) of the image being processed; only its basename is
        used to name the flag file.
    images_output_dir : str
        Directory in which ``<basename of cur_file>.failed`` is written.
    msg : str
        Reason for the failure, written to the flag file.

    Terminates the interpreter via ``sys.exit(0)``.
    """
    if output_h5 is not None and os.path.exists(output_h5):
        os.remove(output_h5)
    clean_up(exdir)
    # Use only the basename: os.path.join discards images_output_dir when
    # cur_file is an absolute path, which would put the flag next to the input.
    errfile = os.path.join(images_output_dir, os.path.basename(cur_file) + ".failed")
    with open(errfile, 'w') as f:
        f.write(msg)
    # NOTE(review): exit status 0 even though processing failed -- presumably
    # so a batch driver continues with the next image; confirm this is intended.
    sys.exit(0)
def _abort_scene(exdir, cur_file, images_output_dir, msg):
    """Clean up, write ``<basename of cur_file>.failed`` with ``msg`` and exit(0).

    Like ``failure()``, but there is no intermediate HDF5 file to delete
    when training directly from the image.
    """
    clean_up(exdir)
    errfile = os.path.join(images_output_dir, os.path.basename(cur_file) + ".failed")
    with open(errfile, 'w') as f:
        f.write(msg)
    # NOTE(review): exit status 0 on failure -- presumably so a batch driver
    # moves on to the next scene; confirm.
    sys.exit(0)


def forestClassifier(config, archive):
    """Train a random forest on one image archive and classify that image.

    *Parameters*
    config : str
        Path to an INI configuration file. Reads ``Directories/log_dir``,
        ``Directories/gsw_path``, ``Directories/output`` and ``Params/num_procs``.
    archive : str
        Path to an image archive, extracted with ``untar_VRT``.

    Writes ``<scene>.txt`` (model evaluation report), ``<scene>.tif``
    (classified image), ``<scene>_classified_filt.gpkg`` and
    ``<scene>_classified_filt.tif`` (post-processed outputs) to the output
    directory, then calls ``clean_up`` on the extracted directory.

    If training raises ``ZeroDivisionError`` (treated as "no water pixels"),
    or the model's F1 score is below ``bandnames.MIN_F1``, a ``.failed`` flag
    file is written instead, no classification is performed, and the process
    exits with status 0.
    """
    # Load configuration file
    Config = configparser.ConfigParser()
    Config.read(config)

    # Set up message logging. basicConfig() is a no-op while the root logger
    # already has handlers, so remove any existing ones first.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logfile = os.path.join(Config.get('Directories', 'log_dir'), "forest.log")
    logging.basicConfig(filename=logfile, level=logging.INFO,
                        format='%(asctime)s - %(name)s - [%(levelname)s] %(message)s')
    logger = logging.getLogger()
    logger.addHandler(logging.StreamHandler())  # also echo to the console

    # Classifier keywords
    gsw_path = Config.get('Directories', 'gsw_path')
    images_output_dir = Config.get('Directories', 'output')
    num_procs = Config.getint('Params', 'num_procs')

    # Extract the archive; cur_file is the path of the VRT inside it.
    cur_file, exdir = untar_VRT(archive)
    logger.info(f"opening archive from manifest: {cur_file}")
    scene_id = os.path.splitext(os.path.basename(cur_file))[0]
    output_basename = os.path.join(images_output_dir, scene_id)
    output_report = os.path.join(images_output_dir, scene_id + '.txt')

    # Train
    #======
    seed = 12345
    RF = forest.waterclass_RF(random_state=seed, n_estimators=250, criterion='entropy',
                              oob_score=True, n_jobs=-1)
    try:
        # A max land-water ratio of 10 is hardcoded here, nland doesn't mean anything
        RF.train_from_image(cur_file, exdir, gsw_path, seed, nland=750, nwater=5000,
                            eval_frac=0.25)
    except ZeroDivisionError:
        logger.error("No water pixels found in scene. Skipping image.")
        msg = ("No overlapping water pixels were found in this scene. "
               "Classification for this image was not performed.")
        logger.error(msg)
        # Previously called failure(exdir) with the wrong arity (TypeError).
        _abort_scene(exdir, cur_file, images_output_dir, msg)

    # NOTE(review): if RF.rf is a scikit-learn estimator, parallelism is set by
    # n_jobs, not num_procs -- confirm this attribute is actually read.
    RF.rf.num_procs = num_procs
    RF.save_evaluation(output_report)
    if RF.results['m']['F1'] < bandnames.MIN_F1:
        msg = ("Poor classification quality found during model fitting"
               " (F1 < {}). "
               "Classification for this image was not performed. Change F1 threshold in the "
               "'bandnames' class (utils.py)".format(bandnames.MIN_F1))
        logger.error(msg)
        # The message promises classification is skipped, so actually skip it
        # (the evaluation report written above is kept for diagnosis).
        _abort_scene(exdir, cur_file, images_output_dir, msg)

    # Classify image
    #================
    output_img = output_basename + '.tif'
    RF.predict_chunked(cur_file, output_img, CHUNK_SIZE)
    del RF  # presumably to release the model before post-processing

    # Postprocess to remove false positives
    #=======================================
    output_polygon = output_basename + "_classified_filt.gpkg"
    low_estimate = output_basename + "_classified_filt.tif"  # created by .postprocess()
    postprocess.postprocess(output_img, output_polygon, output_report)
    postprocess.rasterize_inplace(low_estimate, output_polygon)

    # Clean up
    #=========
    logger.info("cleaning up")
    clean_up(exdir)