Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions topi/include/topi/x86/extern.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file x86/extern.h
* \brief x86 schedule for extern followed by injective operations
*/
#ifndef TOPI_X86_EXTERN_H_
#define TOPI_X86_EXTERN_H_

#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"

namespace topi {
using namespace tvm;

namespace x86 {
/*!
* \brief Schedule a given operation representing one of the outputs of an
* external function which is followed by injective operations.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the output followed by injective operations.
* \param sch The schedule to apply this scheduling to
*
* \return The schedule given by sch
*/
inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
Comment thread
masahi marked this conversation as resolved.
Outdated
auto x = op.output(0);
auto axis = sch[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = detail::Fuse(sch[x], { n, c }); // for nhwc layout, fuse n and h
sch[x].parallel(fused);
} else {
sch[x].parallel(axis[0]);
}
return sch;
}

/*!
* \brief Schedule an extern op followed by injective operations.
* For example, cudnn kernel + bias add + relu
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the op.
*/
inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
Comment thread
soiferj marked this conversation as resolved.
Outdated
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);

tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
if (out->op->derived_from<ExternOpNode>()) {
continue;
}
ScheduleOutputForExtern(target, out->op, s);
}

return s;
}

} // namespace x86
} // namespace topi
#endif // TOPI_X86_EXTERN_H_
1 change: 0 additions & 1 deletion topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_adaptive_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import *
Expand Down
48 changes: 0 additions & 48 deletions topi/python/topi/cuda/extern.py

This file was deleted.

12 changes: 8 additions & 4 deletions topi/python/topi/generic/extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def schedule_extern(outs):
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_extern not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
return tvm.create_schedule([x.op for x in outs])
s = tvm.create_schedule([x.op for x in outs])

tvm.schedule.AutoInlineInjective(s)
for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp):
continue
_schedule_injective(out.op, s)
Copy link
Copy Markdown
Contributor Author

@soiferj soiferj Sep 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vinx13, I have moved this logic from cuda/extern.py to generic/extern.py. Will schedule_injective still call the correct overridden function per target? Or will this just call the default one?

If it just calls the default one, it seems like I have to add a new file, x86/extern.py

Update: it seems like it does call the right one

return s
4 changes: 4 additions & 0 deletions topi/python/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):

@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)

s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
Expand Down