|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +"""Evolutionary Search Strategy""" |
| 18 | + |
| 19 | +from typing import NamedTuple |
| 20 | + |
| 21 | +from tvm._ffi import register_object |
| 22 | + |
| 23 | +from .. import _ffi_api |
| 24 | +from .search_strategy import SearchStrategy |
| 25 | + |
| 26 | + |
| 27 | +@register_object("meta_schedule.EvolutionarySearch") |
| 28 | +class EvolutionarySearch(SearchStrategy): |
| 29 | + """ |
| 30 | + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its |
| 31 | + decisions so that the decisions would be randomly re-generated. |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + num_trials_per_iter : int |
| 36 | + Number of trials per iteration. |
| 37 | + num_trials_total : int |
| 38 | + Total number of trials. |
| 39 | + population_size : int |
| 40 | + The initial population of traces from measured samples and randomly generated samples. |
| 41 | + init_measured_ratio : int |
| 42 | + The ratio of measured samples in the initial population. |
| 43 | + init_max_fail_count : int |
| 44 | + The maximum number to fail trace replaying. |
| 45 | + genetic_num_iters : int |
| 46 | + The number of iterations for genetic algorithm. |
| 47 | + genetic_mutate_prob : float |
| 48 | + The probability of mutation. |
| 49 | + genetic_max_fail_count : int |
| 50 | + The maximum number to retry mutation. |
| 51 | + eps_greedy : float |
| 52 | + The ratio of greedy selected samples in the final picks. |
| 53 | + """ |
| 54 | + |
| 55 | + num_trials_per_iter: int |
| 56 | + num_trials_total: int |
| 57 | + population_size: int |
| 58 | + init_measured_ratio: int |
| 59 | + init_max_fail_count: int |
| 60 | + genetic_num_iters: int |
| 61 | + genetic_mutate_prob: float |
| 62 | + genetic_max_fail_count: int |
| 63 | + eps_greedy: float |
| 64 | + |
| 65 | + def __init__( |
| 66 | + self, |
| 67 | + *, |
| 68 | + num_trials_per_iter: int, |
| 69 | + num_trials_total: int, |
| 70 | + population_size: int, |
| 71 | + init_measured_ratio: float, |
| 72 | + init_max_fail_count: int, |
| 73 | + genetic_num_iters: int, |
| 74 | + genetic_mutate_prob: float, |
| 75 | + genetic_max_fail_count: int, |
| 76 | + eps_greedy: float, |
| 77 | + ) -> None: |
| 78 | + """Constructor""" |
| 79 | + self.__init_handle_by_constructor__( |
| 80 | + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member |
| 81 | + num_trials_per_iter, |
| 82 | + num_trials_total, |
| 83 | + population_size, |
| 84 | + init_measured_ratio, |
| 85 | + init_max_fail_count, |
| 86 | + genetic_num_iters, |
| 87 | + genetic_mutate_prob, |
| 88 | + genetic_max_fail_count, |
| 89 | + eps_greedy, |
| 90 | + ) |
| 91 | + |
| 92 | + |
| 93 | +class EvolutionarySearchConfig(NamedTuple): |
| 94 | + """Configuration for EvolutionarySearch""" |
| 95 | + |
| 96 | + num_trials_per_iter: int |
| 97 | + num_trials_total: int |
| 98 | + population_size: int = 2048 |
| 99 | + init_measured_ratio: float = 0.2 |
| 100 | + init_max_fail_count: int = 64 |
| 101 | + genetic_num_iters: int = 4 |
| 102 | + genetic_mutate_prob: float = 0.85 |
| 103 | + genetic_max_fail_count: int = 10 |
| 104 | + eps_greedy: float = 0.05 |
| 105 | + |
| 106 | + def create_strategy(self) -> EvolutionarySearch: |
| 107 | + return EvolutionarySearch( |
| 108 | + num_trials_per_iter=self.num_trials_per_iter, |
| 109 | + num_trials_total=self.num_trials_total, |
| 110 | + population_size=self.population_size, |
| 111 | + init_measured_ratio=self.init_measured_ratio, |
| 112 | + init_max_fail_count=self.init_max_fail_count, |
| 113 | + genetic_num_iters=self.genetic_num_iters, |
| 114 | + genetic_mutate_prob=self.genetic_mutate_prob, |
| 115 | + genetic_max_fail_count=self.genetic_max_fail_count, |
| 116 | + eps_greedy=self.eps_greedy, |
| 117 | + ) |
0 commit comments