Skip to content

Commit a71c03f

Browse files
committed
[MetaSchedule] Support grouping in the cost model
1 parent 8813d0a commit a71c03f

File tree

10 files changed

+325
-138
lines changed

10 files changed

+325
-138
lines changed

python/tvm/meta_schedule/cost_model/xgb_model.py

Lines changed: 153 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,29 @@
1717
"""
1818
XGBoost-based cost model
1919
"""
20-
from itertools import chain as itertools_chain
2120
import logging
2221
import os
2322
import tempfile
24-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Tuple
23+
from collections import OrderedDict
24+
from itertools import chain as itertools_chain
25+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple
2526

2627
import numpy as np # type: ignore
2728

2829
from ...contrib.tar import tar, untar
30+
from ...runtime import NDArray
2931
from ..cost_model import PyCostModel
3032
from ..feature_extractor import FeatureExtractor
3133
from ..runner import RunnerResult
3234
from ..search_strategy import MeasureCandidate
33-
from ..utils import cpu_count, derived_object
35+
from ..utils import cpu_count, derived_object, shash2hex
3436
from .metric import max_curve
3537

3638
if TYPE_CHECKING:
37-
from ..tune_context import TuneContext
3839
import xgboost as xgb # type: ignore
3940

41+
from ..tune_context import TuneContext
42+
4043

4144
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
4245

@@ -75,8 +78,8 @@ class PackSum:
7578

7679
def __init__(
7780
self,
78-
xs: List[np.ndarray],
79-
ys: Optional[np.ndarray],
81+
xs: List[np.ndarray], # pylint: disable=invalid-name
82+
ys: Optional[np.ndarray], # pylint: disable=invalid-name
8083
):
8184
"""Create PackSum format given a batch of samples
8285
@@ -217,23 +220,63 @@ class XGBConfig(NamedTuple):
217220
Default is None, which means to use physical number of cores.
218221
"""
219222

223+
max_depth: int = 10
224+
gamma: float = 0.001
225+
min_child_weight: float = 0
226+
eta: float = 0.2
227+
seed: int = 43
228+
nthread: Optional[int] = None
229+
220230
def to_dict(self):
221-
xgb_params = {
231+
return {
222232
"max_depth": self.max_depth,
223233
"gamma": self.gamma,
224234
"min_child_weight": self.min_child_weight,
225235
"eta": self.eta,
226236
"seed": self.seed,
227237
"nthread": self.nthread,
228238
}
229-
return xgb_params
230239

231-
max_depth: int = 10
232-
gamma: float = 0.001
233-
min_child_weight: float = 0
234-
eta: float = 0.2
235-
seed: int = 43
236-
nthread: Optional[int] = None
240+
241+
class FeatureGroup:
    """A group of feature vectors and their measured costs, keyed by the hash
    of the workload they were extracted from. The minimum observed cost is
    cached and used as the normalizer for the group.

    Parameters
    ----------
    group_hash : str
        The hash of the group
    features : List[np.ndarray]
        The feature vectors
    costs : np.ndarray
        The measured mean costs, one entry per feature vector
    """

    group_hash: str
    features: List[np.ndarray]
    costs: np.ndarray
    # Derived: cached `np.min(costs)`, kept in sync by `append`
    min_cost: float

    def __init__(
        self,
        group_hash: str,
        features: List[np.ndarray],
        costs: np.ndarray,
    ) -> None:
        self.group_hash = group_hash
        self.features = features
        self.costs = costs
        self.min_cost = np.min(costs)

    def append(
        self,
        features: List[np.ndarray],
        costs: np.ndarray,
    ) -> None:
        """Add new samples to the group and refresh the cached minimum cost.

        Parameters
        ----------
        features : List[np.ndarray]
            The new feature vectors
        costs : np.ndarray
            The new measured mean costs, one entry per feature vector
        """
        self.features.extend(features)
        self.costs = np.append(self.costs, costs)
        self.min_cost = np.min(self.costs)
237280

238281

239282
@derived_object
@@ -268,9 +311,8 @@ class XGBModel(PyCostModel):
268311
verbose_eval: int
269312
average_peak_n: int
270313
# states
271-
cached_features: List[np.ndarray]
272-
cached_mean_costs: np.ndarray
273-
cached_normalizer: Optional[float]
314+
data: Dict[str, FeatureGroup]
315+
data_size: int
274316
booster: Optional["xgb.Booster"]
275317

276318
def __init__(
@@ -293,7 +335,7 @@ def __init__(
293335
# model-related
294336
if config.nthread is None:
295337
# use logical core number
296-
config = config._replace(nthread=cpu_count(logical=False))
338+
config = config._replace(nthread=cpu_count(logical=True))
297339
self.config = config
298340
# behavior of randomness
299341
self.num_warmup_samples = num_warmup_samples
@@ -302,9 +344,8 @@ def __init__(
302344
self.verbose_eval = verbose_eval
303345
self.average_peak_n = average_peak_n
304346
# states
305-
self.cached_features = []
306-
self.cached_mean_costs = np.empty((0,), dtype="float64")
307-
self.cached_normalizer = None
347+
self.data = OrderedDict()
348+
self.data_size = 0
308349
self.booster = None
309350

310351
def load(self, path: str) -> None:
@@ -324,16 +365,29 @@ def load(self, path: str) -> None:
324365
import xgboost as xgb # pylint: disable=import-outside-toplevel
325366

326367
with tempfile.TemporaryDirectory() as tmp_dir:
368+
model_path = os.path.join(tmp_dir, "model.bin")
369+
data_path = os.path.join(tmp_dir, "data.npy")
370+
# Step 1. Untar
327371
untar(path, tmp_dir)
328-
self.booster = xgb.Booster()
329-
self.booster.load_model(os.path.join(tmp_dir, "model.bin"))
330-
self.cached_features = list(
331-
np.load(os.path.join(tmp_dir, "cached_features.npy"), allow_pickle=True)
332-
)
333-
self.cached_mean_costs = np.load(
334-
os.path.join(tmp_dir, "cached_mean_costs.npy"), allow_pickle=True
335-
)
336-
self._set_cached_normalizer()
372+
# Step 2. Load data
373+
data = OrderedDict()
374+
data_size = 0
375+
for group_hash, features, costs in np.load(data_path, allow_pickle=True):
376+
data[group_hash] = FeatureGroup(
377+
group_hash=group_hash,
378+
features=list(features),
379+
costs=costs,
380+
)
381+
data_size += len(costs)
382+
# Step 3. Load the model
383+
if os.path.exists(model_path):
384+
booster = xgb.Booster()
385+
booster.load_model(model_path)
386+
else:
387+
self.booster = None
388+
self.data = data
389+
self.data_size = data_size
390+
self.booster = booster
337391

338392
def save(self, path: str) -> None:
339393
"""Save the cost model to given file location.
@@ -349,26 +403,30 @@ def save(self, path: str) -> None:
349403
previously cached feature vectors and results, so that the subsequent training process could
350404
use all the existing data being stored on disk.
351405
"""
352-
import xgboost as xgb # pylint: disable=import-outside-toplevel
353-
354-
if self.booster is None:
355-
# save all the parameters
356-
self.booster = xgb.Booster(self.config.to_dict())
357406
with tempfile.TemporaryDirectory() as tmp_dir:
358-
self.booster.save_model(os.path.join(tmp_dir, "model.bin"))
407+
model_path = os.path.join(tmp_dir, "model.bin")
408+
data_path = os.path.join(tmp_dir, "data.npy")
409+
# Step 1. Save the model
410+
booster = self.booster
411+
if booster is not None:
412+
booster.save_model(model_path)
413+
else:
414+
model_path = None
415+
# Step 2. Save data
416+
data = [
417+
(
418+
g.group_hash,
419+
g.features,
420+
g.costs,
421+
)
422+
for g in self.data.values()
423+
]
359424
np.save(
360-
os.path.join(tmp_dir, "cached_features.npy"),
361-
np.array(self.cached_features, dtype=object),
362-
)
363-
np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), self.cached_mean_costs)
364-
tar(
365-
path,
366-
[
367-
os.path.join(tmp_dir, "model.bin"),
368-
os.path.join(tmp_dir, "cached_features.npy"),
369-
os.path.join(tmp_dir, "cached_mean_costs.npy"),
370-
],
425+
file=data_path,
426+
arr=np.array(data, dtype=object),
371427
)
428+
# Step 3. Tar it
429+
tar(path, [x for x in [model_path, data_path] if x is not None])
372430
logger.info("Saved XGBModel to %s", path)
373431

374432
def update(
@@ -391,39 +449,55 @@ def update(
391449
assert len(candidates) == len(results)
392450
if len(candidates) == 0:
393451
return
394-
# extract feature and do validation
452+
453+
# Step 1. Get the feature group
454+
new_group_hash = shash2hex(context.mod)
455+
group = self.data.get(new_group_hash, None)
456+
457+
# Step 2. Extract features
458+
def _feature(x: NDArray) -> np.ndarray:
459+
return x.numpy().astype("float32")
395460

396461
def _mean_cost(x: RunnerResult) -> float:
397462
if not x.run_secs:
398463
return 1e10
399464
return float(np.median([float(s) for s in x.run_secs]))
400465

401-
new_features = [
402-
x.numpy().astype("float32") for x in self.extractor.extract_from(context, candidates)
403-
]
404-
new_mean_costs = np.asarray(
405-
[_mean_cost(x) for x in results],
406-
dtype="float32",
407-
)
408-
if self.booster is not None and self.cached_normalizer is not None:
466+
new_features = [_feature(x) for x in self.extractor.extract_from(context, candidates)]
467+
new_mean_costs = np.array([_mean_cost(x) for x in results]).astype("float32")
468+
469+
# Step 3. Run validation
470+
if group is not None and self.booster is not None:
409471
logger.debug(
410472
"XGB validation: %s",
411473
"\t".join(
412474
f"{key}: {score:.6f}"
413475
for key, score in self._validate(
414476
xs=new_features,
415-
ys=new_mean_costs,
477+
ys=group.min_cost / new_mean_costs,
416478
)
417479
),
418480
)
419-
# use together with previous features
420-
self.cached_features.extend(new_features)
421-
self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs)
422-
self._set_cached_normalizer()
423-
# train xgb model
481+
482+
# Step 4. Add the features into the data points
483+
if group is None:
484+
group = FeatureGroup(
485+
group_hash=new_group_hash,
486+
features=new_features,
487+
costs=new_mean_costs,
488+
)
489+
else:
490+
group.append(new_features, new_mean_costs)
491+
self.data[new_group_hash] = group
492+
self.data_size += len(new_features)
493+
494+
# Step 5. Re-train the model
424495
self._train(
425-
xs=self.cached_features,
426-
ys=self.cached_mean_costs,
496+
xs=list(itertools_chain.from_iterable([g.features for g in self.data.values()])),
497+
ys=np.concatenate(
498+
[g.min_cost / g.costs for g in self.data.values()],
499+
axis=0,
500+
),
427501
)
428502

429503
def predict(
@@ -445,10 +519,16 @@ def predict(
445519
result : np.ndarray
446520
The predicted normalized score.
447521
"""
448-
n_measured = len(self.cached_features)
449-
if self.booster is not None and n_measured >= self.num_warmup_samples:
450-
features = self.extractor.extract_from(context, candidates)
451-
ret = self._predict(xs=[x.numpy().astype("float32") for x in features])
522+
if self.data_size >= self.num_warmup_samples and self.booster is not None:
523+
ret = self._predict(
524+
xs=[
525+
x.numpy().astype("float32")
526+
for x in self.extractor.extract_from(
527+
context,
528+
candidates,
529+
)
530+
]
531+
)
452532
else:
453533
ret = np.random.uniform(
454534
low=0,
@@ -464,20 +544,15 @@ def _train( # type: ignore # pylint: disable=invalid-name
464544
) -> None:
465545
import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel
466546

467-
self.d_train = PackSum(
468-
xs=xs,
469-
ys=self.cached_normalizer / ys,
470-
)
547+
self.d_train = PackSum(xs=xs, ys=ys)
471548

472549
def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument
473550
return self.d_train.obj_square_error(ys_pred)
474551

475552
def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument
476553
return self.d_train.rmse(ys_pred)
477554

478-
def average_peak_score(
479-
ys_pred: np.ndarray, d_train: "xgb.DMatrix" # type: ignore # pylint: disable = unused-argument
480-
):
555+
def avg_peak_score(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument
481556
return self.d_train.average_peak_score(ys_pred, self.average_peak_n)
482557

483558
self.booster = xgb.train(
@@ -491,7 +566,7 @@ def average_peak_score(
491566
verbose_eval=self.verbose_eval,
492567
fevals=[
493568
rmse,
494-
average_peak_score,
569+
avg_peak_score,
495570
],
496571
evals=[(self.d_train.dmatrix, "tr")],
497572
)
@@ -528,13 +603,9 @@ def _validate( # type: ignore # pylint: disable=invalid-name
528603
scores: np.ndarray
529604
The predicted result for all inputs.
530605
"""
531-
if self.booster is None or self.cached_normalizer is None:
532-
return []
606+
assert self.booster is not None
533607

534-
d_valid = PackSum(
535-
xs=xs,
536-
ys=self.cached_normalizer / ys,
537-
)
608+
d_valid = PackSum(xs=xs, ys=ys)
538609

539610
def average_peak_score(ys_pred: np.ndarray):
540611
return d_valid.average_peak_score(ys_pred, n=self.average_peak_n)
@@ -550,14 +621,6 @@ def average_peak_score(ys_pred: np.ndarray):
550621
eval_result.sort(key=make_metric_sorter("p-rmse"))
551622
return eval_result
552623

553-
def _set_cached_normalizer(self) -> None:
554-
filtered = self.cached_mean_costs[self.cached_mean_costs > 0]
555-
if filtered.size == 0:
556-
self.cached_normalizer = 1.0
557-
else:
558-
self.cached_normalizer = np.min(filtered)
559-
assert self.cached_normalizer > 0
560-
561624

562625
def custom_callback(
563626
early_stopping_rounds: int,

0 commit comments

Comments
 (0)